-
Notifications
You must be signed in to change notification settings - Fork 24
/
layernorm.py
372 lines (342 loc) · 12.2 KB
/
layernorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
from utils import size
from typing import List, Tuple
from hardware_model.device import Device
from software_model.operators import Operator
from software_model.utils import Tensor, DataType
from math import ceil, log2, log
import time
import statistics
import numpy as np
import torch
@torch.compile
def layernorm_gpu(input: torch.Tensor) -> torch.Tensor:
    """Normalize ``input`` over its last axis (no affine weight or bias).

    The function is wrapped by ``torch.compile``, so its first call triggers
    compilation; callers that time it should warm it up first.
    """
    normalized_shape = input.shape[-1:]
    return torch.nn.functional.layer_norm(input, normalized_shape)
class LayerNorm(Operator):
    """Performance model of LayerNorm over the last axis of a tensor.

    The input is viewed as an ``M x N`` matrix: ``N`` is the last
    (normalized) axis and ``M`` is ``size`` of the leading axes. Latency can
    be estimated analytically (``roofline_model``), by a tiled L2/L1
    simulation on a ``Device`` (``compile_and_simulate``), or measured with
    PyTorch on a CUDA GPU (``run_on_gpu``).
    """

    def __init__(self, data_type: DataType):
        # The four zeros are forwarded positionally to ``Operator``; see the
        # base class for their meaning.
        super().__init__(0, 0, 0, 0, data_type)
        # Set by ``__call__``; ``run_on_gpu`` requires it.
        self.shape = None

    def __call__(self, input: Tensor) -> Tensor:
        """Record ``input``'s shape and build the ``M x N`` computational graph.

        LayerNorm does not change the shape, so ``input`` is returned as-is.
        """
        assert self.data_type == input.data_type
        self.shape = input.shape
        # Rows: ``size`` of all leading axes; columns: the normalized axis.
        self.M = size(input.shape[:-1])
        self.N = input.shape[-1]
        self.computational_graph = self.ComputationalGraph(
            self.M, self.N, self.data_type
        )
        return input

    def roofline_model(self, pcb_module: Device):
        """Return the roofline latency estimate (presumably in seconds).

        The latency is the larger of two bounds:

        * a memory bound, using the smaller of off-chip bandwidth and L2
          bandwidth (L2 bytes per cycle times clock frequency);
        * a compute bound, using the device's total vector flops.
        """
        # Factor 2: presumably one read of the input plus one write of the
        # output.
        self.io_count = self.M * self.N * self.data_type.word_size * 2
        # 7 flops per element, matching the 1 + 2 + 4 passes charged by
        # ``L1TileSimulator.simulate_l1_tile_compute_cycle_count``.
        self.flop_count = self.M * self.N * 7
        compute_module = pcb_module.compute_module
        memory_bandwidth = min(
            pcb_module.io_module.bandwidth,
            compute_module.l2_bandwidth_per_cycle * compute_module.clock_freq,
        )
        self.roofline_latency = max(
            self.io_count / memory_bandwidth,
            self.flop_count / compute_module.total_vector_flops,
        )
        return self.roofline_latency

    def print_latency(self):
        """Print the recorded shape and the latency measured by ``run_on_gpu``."""
        print(f"{self.shape}, {self.latency_on_gpu*1e6}us")

    class ComputationalGraph:
        """Problem description: an ``M x N`` LayerNorm in ``data_type``."""

        def __init__(self, M: int, N: int, data_type: DataType):
            self.M = M
            self.N = N
            self.data_type = data_type

    class Mapping:
        """Tiling choice: L2 tile extents and L1 (core SRAM) tile extents."""

        def __init__(
            self,
            l2_tile_M: int,
            l2_tile_N: int,
            l1_tile_M: int,
            l1_tile_N: int,
        ):
            self.l2_tile_M = l2_tile_M
            self.l2_tile_N = l2_tile_N
            self.l1_tile_M = l1_tile_M
            self.l1_tile_N = l1_tile_N

        def display(self):
            """Print every tile extent of this mapping."""
            print("-" * 20)
            print(
                f"l2_tile_M: {self.l2_tile_M}, l2_tile_N: {self.l2_tile_N}, "
                f"l1_tile_M: {self.l1_tile_M}, l1_tile_N: {self.l1_tile_N}"
            )

    def compile_and_simulate(self, pcb_module: Device, compile_mode: str):
        """Choose a tiling for ``compile_mode``, simulate it, and return its latency.

        Args:
            pcb_module: Target device model.
            compile_mode: ``"heuristic-GPU"``, ``"heuristic-our-throughput"``
                or ``"heuristic-TPU"``.

        Returns:
            Simulated cycles divided by the clock frequency. The value is also
            stored in ``self.latency`` and ``self.best_latency``.

        Raises:
            ValueError: If ``compile_mode`` is not one of the modes above.
        """
        # The simulation runs in the vector unit's native data type.
        self.computational_graph.data_type = (
            pcb_module.compute_module.core.vector_unit.data_type
        )
        M = self.computational_graph.M
        N = self.computational_graph.N
        data_type = self.computational_graph.data_type
        compute_module = pcb_module.compute_module
        core = compute_module.core

        # L2 tile: whole rows, as many as fit in half of L2. Clamp to >= 1 so
        # that a row larger than half of L2 (or M == 0) cannot yield a
        # zero-sized tile, which would divide by zero in ``simulate``.
        l2_tile_N = N
        l2_tile_M = compute_module.l2_size // (l2_tile_N * data_type.word_size) // 2
        l2_tile_M = max(1, min(l2_tile_M, M))

        if compile_mode in ("heuristic-GPU", "heuristic-our-throughput"):
            # Start with whole rows. Halve the column extent until at least
            # ``vector_count`` rows fit in half of the core SRAM. Stop at
            # one column so the tile width never reaches zero.
            l1_tile_N = N
            l1_tile_M = core.SRAM_size // (l1_tile_N * data_type.word_size) // 2
            while l1_tile_M < core.vector_unit.vector_count and l1_tile_N > 1:
                l1_tile_N = l1_tile_N // 2
                l1_tile_M = (
                    core.SRAM_size // (l1_tile_N * data_type.word_size) // 2
                )
            l1_tile_M = min(l1_tile_M, l2_tile_M)
        elif compile_mode == "heuristic-TPU":
            l1_tile_N = N
            l1_tile_M = core.SRAM_size // (2 * l1_tile_N * data_type.word_size)
            l1_tile_M = min(l1_tile_M, M)
        else:
            raise ValueError(f"Unsupported compile_mode: {compile_mode!r}")
        # Guard against an SRAM too small to hold even one row of the tile.
        l1_tile_M = max(1, l1_tile_M)

        mapping = self.Mapping(l2_tile_M, l2_tile_N, l1_tile_M, l1_tile_N)
        cycle_count = self.simulate(self.computational_graph, mapping, pcb_module)

        # Each mode yields exactly one candidate mapping, so it is the best one.
        self.best_mapping = mapping
        self.best_cycle_count = cycle_count
        self.best_latency = cycle_count / compute_module.clock_freq
        self.latency = self.best_latency
        return self.latency

    def simulate(
        self,
        computational_graph: ComputationalGraph,
        mapping: Mapping,
        pcb_module: Device,
    ) -> int:
        """Return the total cycles to stream every L2 tile (read, compute, write).

        L2 tiles run one after another, and their read, compute and write
        phases are simply added (no overlap). All full tiles cost the same, so
        one of them is simulated and its cost is scaled by the number of full
        tiles. A trailing partial tile (``M % l2_tile_M`` rows) is simulated
        separately. The result may be a float because some per-tile costs are
        fractional.
        """
        M = computational_graph.M
        N = computational_graph.N
        data_type = computational_graph.data_type
        l2_tile_M = mapping.l2_tile_M

        full_tile_count, remainder_rows = divmod(M, l2_tile_M)

        def tile_cycles(rows: int):
            # Sequential read + compute + write of one L2 tile of ``rows`` rows.
            tile = self.L2TileSimulator(rows, N, data_type, mapping, pcb_module)
            return (
                tile.read_cycle_count
                + tile.compute_cycle_count
                + tile.write_cycle_count
            )

        total_cycle_count = 0
        if full_tile_count:
            total_cycle_count += full_tile_count * tile_cycles(l2_tile_M)
        if remainder_rows:
            total_cycle_count += tile_cycles(remainder_rows)
        return total_cycle_count

    class L2TileSimulator:
        """Cycle counts for one L2 tile of ``M x N`` elements.

        Sets ``read_cycle_count`` and ``write_cycle_count`` (transfer of the
        tile over the device IO link) and ``compute_cycle_count`` (all L1
        tiles of this tile, spread over the cores).
        """

        def __init__(
            self,
            M: int,
            N: int,
            data_type: DataType,
            mapping: "LayerNorm.Mapping",
            pcb_module: Device,
        ):
            self.M = M
            self.N = N
            # Input and output tiles are the same size, so reading and
            # writing use the same transfer cost.
            self.read_cycle_count = self.simulate_l2_tile_io_cycle_count(
                M, N, data_type, pcb_module
            )
            self.write_cycle_count = self.simulate_l2_tile_io_cycle_count(
                M, N, data_type, pcb_module
            )
            self.compute_cycle_count = self.simulate_l2_tile_compute_cycle_count(
                M, N, data_type, mapping, pcb_module
            )

        def simulate_l2_tile_io_cycle_count(
            self, M: int, N: int, data_type: DataType, chiplet_module: Device
        ):
            """Return the cycles to move ``M * N`` elements over the IO link (rounded up)."""
            bytes_per_cycle = (
                chiplet_module.io_module.bandwidth
                / chiplet_module.compute_module.clock_freq
            )
            return ceil(M * N * data_type.word_size / bytes_per_cycle)

        def simulate_l2_tile_compute_cycle_count(
            self,
            M: int,
            N: int,
            data_type: DataType,
            mapping: "LayerNorm.Mapping",
            pcb_module: Device,
        ):
            """Return the cycles to process this tile's L1 tiles on all cores.

            L1 tiles are charged in ``ceil(l1_tile_count / core_count)``
            waves. Each wave costs one L1 tile's read/compute/write plus one
            reduction for every extra split of a row along N.
            """
            l1_tile_M = mapping.l1_tile_M
            l1_tile_N = mapping.l1_tile_N
            # NOTE(review): the L1 tile is always simulated at the full mapped
            # size, even when this L2 tile (e.g. the trailing partial one) has
            # fewer than ``l1_tile_M`` rows, so such tiles are over-estimated.
            l1_tile = LayerNorm.L1TileSimulator(
                l1_tile_M,
                l1_tile_N,
                data_type,
                mapping,
                pcb_module,
            )
            n_tiles_along_N = ceil(N / l1_tile_N)
            l1_tile_count = ceil(M / l1_tile_M) * n_tiles_along_N
            # Three input reads per L1 tile (presumably one per pass: mean,
            # variance, normalization), one write, plus compute.
            l1_tile_cycle_count = (
                l1_tile.read_cycle_count * 3
                + l1_tile.write_cycle_count
                + l1_tile.compute_cycle_count
            )
            return ceil(l1_tile_count / pcb_module.compute_module.core_count) * (
                l1_tile_cycle_count
                + (n_tiles_along_N - 1) * l1_tile.reduction_cycle_count
            )

    class L1TileSimulator:
        """Cycle counts for one ``M x N`` L1 tile processed on a single core."""

        def __init__(
            self,
            M: int,
            N: int,
            data_type: DataType,
            mapping: "LayerNorm.Mapping",
            pcb_module: Device,
        ):
            self.M = M
            self.N = N
            compute_module = pcb_module.compute_module
            self.read_cycle_count = self.simulate_l1_tile_io_cycle_count(
                M, N, data_type, pcb_module
            )
            self.compute_cycle_count = self.simulate_l1_tile_compute_cycle_count(
                M, N, data_type, mapping, pcb_module
            )
            self.write_cycle_count = self.simulate_l1_tile_io_cycle_count(
                M, N, data_type, pcb_module
            )
            # Charged once per extra split of a row along N. It consists of
            # M*N flops on the core's vector units plus 2*M*N elements moved
            # through this core's share of the L2 bandwidth.
            self.reduction_cycle_count = (
                M
                * N
                / compute_module.core.vector_unit.total_vector_flops_per_cycle
                + M
                * N
                * data_type.word_size
                * 2
                / (compute_module.l2_bandwidth_per_cycle / compute_module.core_count)
            )

        def simulate_l1_tile_io_cycle_count(
            self, M: int, N: int, data_type: DataType, pcb_module: Device
        ):
            """Return the cycles to move ``M * N`` elements at full L2 bandwidth (rounded up)."""
            return ceil(
                M
                * N
                * data_type.word_size
                / (pcb_module.compute_module.l2_bandwidth_per_cycle)
            )

        def simulate_l1_tile_compute_cycle_count(
            self,
            M: int,
            N: int,
            data_type: DataType,
            mapping: "LayerNorm.Mapping",
            pcb_module: Device,
        ):
            """Return the cycles for mean, variance and normalization of the tile.

            The cost is seven element-wise passes (1 for the mean, 2 for the
            variance, 4 for the normalized output) plus two
            ``log2(vector_width)`` cross-lane reductions.
            """
            vector_unit = pcb_module.compute_module.core.vector_unit
            # Rows are spread over the vector units; within a unit each lane
            # covers ``ceil(N / vector_width)`` columns of a row.
            M_per_vector_lane = ceil(M / vector_unit.vector_count)
            N_per_vector_lane = ceil(N / vector_unit.vector_width)
            one_pass = ceil(
                N_per_vector_lane * M_per_vector_lane / vector_unit.flops_per_cycle
            )
            lane_reduction = log2(vector_unit.vector_width)

            total_cycle_count = one_pass  # each lane computes its partial mean
            total_cycle_count += lane_reduction  # lanes reduce to one mean
            total_cycle_count += one_pass * 2  # each lane computes its variance
            total_cycle_count += lane_reduction  # lanes reduce to one variance
            total_cycle_count += one_pass * 4  # normalized output; division is heavy
            return total_cycle_count

    def run_on_gpu(self):
        """Measure the median latency of ``layernorm_gpu`` on CUDA, in seconds.

        Uses a random fp16 tensor with the shape recorded by ``__call__``.
        Three warm-up calls run first; they also trigger ``torch.compile``.
        Then ``self.iterations`` calls are timed, each followed by a CUDA
        synchronize. ``self.iterations`` is not set in this class; presumably
        it comes from ``Operator`` or the caller.
        """
        assert self.shape is not None
        x = torch.randn(self.shape, dtype=torch.float16, device="cuda")
        for _ in range(3):
            layernorm_gpu(x)
        torch.cuda.synchronize()

        latencies = []
        for _ in range(self.iterations):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            output = layernorm_gpu(x)
            torch.cuda.synchronize()
            end = time.perf_counter()
            assert output.shape == x.shape
            latencies.append(end - start)
        self.latency_on_gpu = statistics.median(latencies)
        return self.latency_on_gpu

    @staticmethod
    def gpu_kernel_launch_overhead():
        """Estimate the fixed per-call overhead of ``layernorm_gpu``, in seconds.

        Times ``layernorm_gpu`` on a single-element CUDA tensor, so the result
        is presumably dominated by launch/dispatch cost. Returns the median of
        50 synchronized calls and prints the raw latencies.
        """
        a = torch.randn(1, 1, 1, device="cuda")
        # Warm up so the one-off torch.compile cost for this shape is not timed.
        for _ in range(3):
            layernorm_gpu(a)
        torch.cuda.synchronize()

        latencies = []
        for _ in range(50):
            start = time.perf_counter()
            layernorm_gpu(a)
            torch.cuda.synchronize()
            end = time.perf_counter()
            latencies.append(end - start)
        median_overhead = statistics.median(latencies)
        print(latencies)
        return median_overhead