-
Notifications
You must be signed in to change notification settings - Fork 38
/
utils.py
227 lines (205 loc) · 8.76 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import platform
import sys
import shlex
import time
import torch
import numpy as np
from pathlib import Path
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
from megatron.model.transformer import ParallelSelfAttention, ParallelMLP, ParallelTransformerLayer
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.activations import bias_gelu_impl
from megatron.model.gpt2_model import gpt2_attention_mask_func as attention_mask_func
from megatron.model.word_embeddings import Embedding
def print_benchmark_header(notes="None"):
    """Print a banner describing the environment of this benchmark run.

    The banner lists, in order: the local start time, the interpreter plus the
    shell-quoted command line, the ``platform.uname()`` fields, the properties
    of the current CUDA device, the torch / CUDA / NCCL versions, and the
    free-form ``notes``. It is closed by an 80-character dashed separator.

    Args:
        notes: arbitrary object rendered under "** Additional notes:".
    """
    started = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    command = " ".join(shlex.quote(arg) for arg in sys.argv)
    host = " ".join(platform.uname())
    device = torch.cuda.get_device_properties(torch.device('cuda'))
    versions = (
        f"torch={torch.__version__}, cuda={torch.version.cuda}, "
        f"nccl={torch.cuda.nccl.version()}"
    )
    # Leading/trailing empty entries reproduce the newline before the first
    # line and after the separator.
    banner = [
        "",
        f"Benchmark started on {started}",
        "** Command line:",
        f"{sys.executable} {command}",
        "** Platform:",
        host,
        f"{device}",
        "** Critical component versions:",
        versions,
        "** Additional notes:",
        f"{notes}",
        "-" * 80,
        "",
    ]
    print("\n".join(banner))
class Tee:
    """File-like writer that logs to a file and optionally echoes to stdout.

    The ``sys.stdout`` in effect at construction time is captured for echoing.
    The object can therefore be installed as ``sys.stdout`` afterwards without
    ``write`` recursing into itself.

    Can be used as a context manager; the log file is closed on exit.

    Args:
        filename: path of the log file. Missing parent directories are
            created, and an existing file is truncated.
        verbose: if true, every write/flush is also forwarded to the captured
            stdout.
    """

    def __init__(self, filename, verbose):
        Path(filename).resolve().parent.mkdir(parents=True, exist_ok=True)
        self.file = open(filename, "w")
        self.verbose = verbose
        # Always defined (None when not echoing) so the attribute exists
        # regardless of ``verbose``.
        self.stdout = sys.stdout if verbose else None

    def write(self, message):
        """Write ``message`` to the log file and, if verbose, to stdout."""
        self.file.write(message)
        if self.verbose:
            self.stdout.write(message)

    def flush(self):
        """Flush the log file and, if verbose, the echoed stdout."""
        self.file.flush()
        if self.verbose:
            self.stdout.flush()

    def close(self):
        """Close the log file (which flushes it); stdout is only flushed.

        Safe to call more than once.
        """
        if not self.file.closed:
            self.file.close()
        if self.verbose:
            self.stdout.flush()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
def display(shape):
    """Format a tensor shape as an ``x``-separated string.

    ``(2, 3, 4)`` becomes ``"2x3x4"``. A bare integer is also accepted, since
    ``torch.randn`` takes an int size too, and is rendered as-is.

    Args:
        shape: an int, or an iterable of dimensions (tuple, list, torch.Size).

    Returns:
        The formatted shape string.
    """
    if isinstance(shape, int):
        return str(shape)
    return "x".join(str(dim) for dim in shape)
def benchmark_mm(m, n, k, num_iterations, num_warmup_iterations):
    """Benchmark a basic fp16 GEMM ``C = A @ B`` on the GPU.

    ``A`` is (m, n), ``B`` is (n, k) and ``C`` is (m, k). Every iteration is
    bracketed by CUDA events and synchronized. Warmup iterations are
    discarded, and the fastest timed iteration is reported together with the
    derived throughput (2*m*n*k FLOPs).

    Args:
        m, n, k: GEMM dimensions.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Validate before touching the GPU. Otherwise zero timed iterations only
    # fail at np.amin after all the work, and a negative warmup misslices.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate straight on the device in fp16. Only the timed op matters, so
    # there is no need to build fp32 tensors on the host and copy them over.
    A = torch.randn(m, n, dtype=torch.float16, device="cuda")
    B = torch.randn(n, k, dtype=torch.float16, device="cuda")
    C = torch.empty(m, k, dtype=torch.float16, device="cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            torch.mm(A, B, out=C)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {m}x{n}x{k}: {elapsed_time:.3f}")
    print(f"Throughput (in TFLOP/s) for {m}x{n}x{k}: {(2 * m * n * k) / (elapsed_time * 10**12):.3f}")
    print("-" * 80)
    return elapsed_time
def benchmark_mm_b(m, n, k, label, b, num_iterations, num_warmup_iterations):
    """Benchmark a GEMM issued as one ``torch.nn.functional.linear`` call.

    ``B`` is (k, n) and ``linear(A, B)`` computes ``A @ B.T``:
      * ``b is None``: ``A`` is (m, n) and the output (m, k). It is reported
        as ``b=1``.
      * otherwise: ``A`` is (b, m, n) and the output (b, m, k). That is a
        single batched operator sharing one weight.
    All tensors are fp16 on the GPU. Warmup iterations are discarded, and the
    fastest timed iteration is reported with the throughput (2*b*m*n*k FLOPs).

    Args:
        m, n, k: GEMM dimensions.
        label: name printed in the report.
        b: batch size, or None for an unbatched 2-D input.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast with a clear message instead of numpy's zero-size amin error
    # after all the GPU work (or a silent misslice for negative warmup).
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    B = torch.randn((k, n), dtype=torch.float16, device="cuda")
    if b is None:
        A = torch.randn((m, n), dtype=torch.float16, device="cuda")
        C = torch.empty((m, k), dtype=torch.float16, device="cuda")
        b = 1  # unbatched input is reported and counted as a single batch
    else:
        A = torch.randn((b, m, n), dtype=torch.float16, device="cuda")
        C = torch.empty((b, m, k), dtype=torch.float16, device="cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            torch.nn.functional.linear(A, B, out=C)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({m}x{n}x{k}, b={b}): {elapsed_time:.4f}")
    print(f"Throughput (in TFLOP/s) for {label} ({m}x{n}x{k}, b={b}): "
          f"{(2 * b * m * n * k) / (elapsed_time * 10**12):.3f}")
    return elapsed_time
def benchmark_bmm(b, m, n, k, label, num_iterations, num_warmup_iterations):
    """Benchmark an fp16 batched matmul ``torch.bmm`` on the GPU.

    ``A`` is (b, m, n), ``B`` is (b, n, k) and the output (b, m, k). Warmup
    iterations are discarded, and the fastest timed iteration is reported with
    the throughput (2*b*m*n*k FLOPs).

    Args:
        b: batch count.
        m, n, k: per-batch GEMM dimensions.
        label: name printed in the report.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast before any GPU work; zero timed iterations would otherwise
    # surface as an opaque numpy error at the very end.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    A = torch.randn((b, m, n), dtype=torch.float16, device="cuda")
    B = torch.randn((b, n, k), dtype=torch.float16, device="cuda")
    C = torch.empty((b, m, k), dtype=torch.float16, device="cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            torch.bmm(A, B, out=C)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({b}x{m}x{n}x{k}): {elapsed_time:.4f}")
    print(f"Throughput (in TFLOP/s) for {label} ({b}x{m}x{n}x{k}): "
          f"{(2 * b * m * n * k) / (elapsed_time * 10**12):.3f}")
    return elapsed_time
def benchmark_dropout(A_dim, label, num_iterations, num_warmup_iterations):
    """Benchmark ``torch.nn.Dropout(0.5)`` applied to an fp16 GPU tensor.

    The Dropout module is left in its default (training) mode, so elements
    are actually dropped. Warmup iterations are discarded, and the fastest
    timed iteration is reported.

    Args:
        A_dim: input size, anything ``torch.randn`` accepts (int or tuple).
        label: name printed in the report.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast before any GPU work.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    A = torch.randn(A_dim, dtype=torch.float16, device="cuda")
    dropout = torch.nn.Dropout(0.5).to("cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            dropout(A)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({display(A_dim)}): {elapsed_time:.4f}")
    return elapsed_time
def benchmark_softmax(scores_shape, seq_length, label, num_iterations, num_warmup_iterations):
    """Benchmark Megatron's ``FusedScaleMaskSoftmax`` with a causal mask.

    The mask is built as a (1, 1, seq_length, seq_length) boolean tensor. It
    is True strictly above the diagonal, i.e. at positions excluded by the
    lower-triangular pattern. It is configured with
    ``SoftmaxFusionTypes.none``. Warmup iterations are discarded, and the
    fastest timed iteration is reported.

    NOTE(review): assumes ``scores_shape`` ends in (seq_length, seq_length)
    so the scores broadcast against the mask - TODO confirm with callers.

    Args:
        scores_shape: shape of the fp16 attention-score tensor.
        seq_length: sequence length used to build the causal mask.
        label: name printed in the report.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast before any GPU work.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    scores = torch.randn(scores_shape, dtype=torch.float16, device="cuda")
    attention_mask = torch.tril(torch.ones(
        (1, seq_length, seq_length), device="cuda")).view(
        1, 1, seq_length, seq_length)
    attention_mask = attention_mask < 0.5  # True above the diagonal
    # NOTE(review): the positional args presumably map to (input_in_fp16,
    # input_in_bf16, fusion_type, mask_func, softmax_in_fp32, scale) - verify
    # against megatron.model.fused_softmax before relying on this.
    softmax = FusedScaleMaskSoftmax(
        True, False,
        SoftmaxFusionTypes.none,
        attention_mask_func, True, 1)

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            softmax(scores, attention_mask)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({display(scores_shape)}): {elapsed_time:.4f}")
    return elapsed_time
def benchmark_fused_gelu(A_dim, b_dim, label, num_iterations, num_warmup_iterations):
    """Benchmark Megatron's ``bias_gelu_impl`` on fp16 GPU tensors.

    NOTE(review): ``bias_gelu_impl(A, b)`` presumably computes
    ``gelu(A + b)`` in one fused kernel - verify against
    megatron.model.activations.

    Warmup iterations are discarded, and the fastest timed iteration is
    reported.

    Args:
        A_dim: input size (int or tuple, as accepted by ``torch.randn``).
        b_dim: bias size (int or tuple).
        label: name printed in the report.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast before any GPU work.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    A = torch.randn(A_dim, dtype=torch.float16, device="cuda")
    b = torch.randn(b_dim, dtype=torch.float16, device="cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            bias_gelu_impl(A, b)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({display(A_dim)}): {elapsed_time:.4f}")
    return elapsed_time
def benchmark_layer_norm(A_dim, normalized_shape, label, num_iterations, num_warmup_iterations):
    """Benchmark Megatron's ``LayerNorm`` (converted to fp16, on the GPU).

    Warmup iterations are discarded, and the fastest timed iteration is
    reported.

    Args:
        A_dim: input size (int or tuple, as accepted by ``torch.randn``).
        normalized_shape: passed straight to ``LayerNorm``.
        label: name printed in the report.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast before any GPU work.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    A = torch.randn(A_dim, dtype=torch.float16, device="cuda")
    layer_norm = LayerNorm(normalized_shape).half().to("cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            layer_norm(A)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({display(A_dim)}): {elapsed_time:.4f}")
    return elapsed_time
def benchmark_add_bias_dropout(shape, label, num_iterations, num_warmup_iterations):
    """Benchmark Megatron's ``bias_dropout_add_fused_train`` on fp16 GPU tensors.

    The input, bias and residual all have the same ``shape`` here, so no
    broadcasting is exercised.

    NOTE(review): the trailing ``0.0`` is presumably the dropout probability,
    i.e. nothing is dropped - verify against megatron.model.transformer.

    Warmup iterations are discarded, and the fastest timed iteration is
    reported.

    Args:
        shape: tensor size (int or tuple, as accepted by ``torch.randn``).
        label: name printed in the report.
        num_iterations: number of timed iterations; must be >= 1.
        num_warmup_iterations: number of discarded warmup iterations; must be >= 0.

    Returns:
        The minimum per-iteration time, in seconds.

    Raises:
        ValueError: if ``num_iterations < 1`` or ``num_warmup_iterations < 0``.
    """
    # Fail fast before any GPU work.
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if num_warmup_iterations < 0:
        raise ValueError(
            f"num_warmup_iterations must be >= 0, got {num_warmup_iterations}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Allocate directly on the device in fp16 (no host fp32 round trip).
    A = torch.randn(shape, dtype=torch.float16, device="cuda")
    bias = torch.randn(shape, dtype=torch.float16, device="cuda")
    residue = torch.randn(shape, dtype=torch.float16, device="cuda")

    total = num_warmup_iterations + num_iterations
    times = np.zeros(total)
    with torch.no_grad():
        for i in range(total):
            start.record()
            bias_dropout_add_fused_train(A, bias, residue, 0.0)
            end.record()
            torch.cuda.synchronize()
            times[i] = start.elapsed_time(end)  # milliseconds
    elapsed_time = np.amin(times[num_warmup_iterations:]) / 1000  # seconds
    print(f"Elapsed time for {label} ({display(shape)}): {elapsed_time:.4f}")
    return elapsed_time