from utils import size
from typing import List, Tuple
from hardware_model.device import Device
from software_model.operators import Operator
from software_model.utils import Tensor, DataType
from math import ceil, log2, log
import time
import statistics
import numpy as np
import torch


@torch.compile
def gelu_gpu(input: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GeLU:
    # x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    return torch.nn.functional.gelu(input, approximate="tanh")


class GeLU(Operator):
    def __init__(self, data_type: DataType):
        super().__init__(0, 0, 0, 0, data_type)
        self.shape = None
        self.iterations = 10  # assumed default number of timed runs for run_on_gpu

    def __call__(self, input: Tensor) -> Tensor:
        # GeLU is elementwise: the output tensor has the input's shape and dtype.
        assert self.data_type == input.data_type
        self.shape = input.shape
        self.M = size(input.shape[:])  # total number of elements
        self.computational_graph = self.ComputationalGraph(self.M, self.data_type)
        return input

    def roofline_model(self, pcb_module: Device):
        self.computational_graph.data_type = (
            pcb_module.compute_module.core.vector_unit.data_type
        )
        M = self.M
        data_type = self.computational_graph.data_type
        # IO: every element is read once and written once.
        total_io_count = M * 2 * data_type.word_size
        io_latency = total_io_count / min(
            pcb_module.io_module.bandwidth,
            pcb_module.compute_module.l2_bandwidth_per_cycle
            * pcb_module.compute_module.clock_freq,
        )
        # Compute: ~10 elementwise flops per element plus one transcendental,
        # costed via flops_per_exp.
        total_flop_count = M * (
            10 + pcb_module.compute_module.core.vector_unit.flops_per_exp
        )
        compute_latency = (
            total_flop_count
            / pcb_module.compute_module.core.vector_unit.total_vector_flops_per_cycle
            / pcb_module.compute_module.core_count
            / pcb_module.compute_module.clock_freq
        )
        self.roofline_latency = max(compute_latency, io_latency)
        return self.roofline_latency

    def print_latency(self):
        print(f"{self.shape}, {self.latency_on_gpu*1e6}us")

    class ComputationalGraph:
        def __init__(self, M: int, data_type: DataType):
            self.M = M
            self.data_type = data_type

    def compile_and_simulate(self, pcb_module: Device, compile_mode: str):
        self.computational_graph.data_type = (
            pcb_module.compute_module.core.vector_unit.data_type
        )
        # Round the element count up to a multiple of the device's total vector lanes.
        parallelism = (
            pcb_module.compute_module.core_count
            * pcb_module.compute_module.core.vector_unit.vector_width
            * pcb_module.compute_module.core.vector_unit.vector_count
        )
        M = ceil(self.computational_graph.M / parallelism) * parallelism
        data_type = self.computational_graph.data_type
        total_io_count = M * 2 * data_type.word_size
        io_latency = (
            total_io_count / pcb_module.io_module.bandwidth
            + total_io_count
            / pcb_module.compute_module.l2_bandwidth_per_cycle
            / pcb_module.compute_module.clock_freq
        )
        total_flop_count = M * (
            10 + pcb_module.compute_module.core.vector_unit.flops_per_exp
        )
        compute_latency = (
            total_flop_count
            / pcb_module.compute_module.core.vector_unit.total_vector_flops_per_cycle
            / pcb_module.compute_module.core_count
            / pcb_module.compute_module.clock_freq
        )
        return max(compute_latency, io_latency)
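
    # Worked example (illustrative numbers, not taken from this repository): with
    # M = 1,048,576 fp16 elements (word_size = 2 bytes), total_io_count is 4 MiB;
    # at 2 TB/s of off-chip bandwidth the IO term alone is roughly 2 us, while the
    # handful of flops per element usually keeps compute_latency below that, so
    # this operator tends to be bandwidth-bound on modern accelerators.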

    def run_on_gpu(self):
        assert self.shape is not None
        input = torch.randn(self.shape, dtype=torch.float16, device="cuda")
        latencies = []
        # warmup
        for _ in range(3):
            _ = gelu_gpu(input)
            torch.cuda.synchronize()
        for _ in range(self.iterations):
            start = time.time()
            output = gelu_gpu(input)
            torch.cuda.synchronize()
            end = time.time()
            assert output.shape == input.shape
            latencies.append(end - start)
        # print(latencies)
        self.latency_on_gpu = statistics.median(latencies)
        return self.latency_on_gpu

    @staticmethod
    def gpu_kernel_launch_overhead():
        import torch

        size = 1
        latencies = []
        for _ in range(50):
            a = torch.randn(size, size, device="cuda")
            torch.cuda.synchronize()
            start = time.time()
            c = gelu_gpu(a)
            torch.cuda.synchronize()
            end = time.time()
            latencies.append(end - start)
        avg_overhead = statistics.median(latencies)
        # print('GPU kernel launch overhead: ', avg_overhead*1e3, 'ms')
        print(latencies)
        return avg_overhead
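

if __name__ == "__main__":
    # Standalone smoke test: a minimal sketch (not part of the original module) that
    # times gelu_gpu directly and reports per-call kernel launch overhead. It assumes
    # a CUDA-capable device; the (4096, 4096) fp16 shape is an arbitrary example.
    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
        # Warm up so torch.compile's first-call compilation is not timed.
        for _ in range(3):
            _ = gelu_gpu(x)
        torch.cuda.synchronize()
        runs = []
        for _ in range(10):
            start = time.time()
            _ = gelu_gpu(x)
            torch.cuda.synchronize()
            runs.append(time.time() - start)
        print(f"gelu_gpu median latency: {statistics.median(runs) * 1e6:.1f} us")
        GeLU.gpu_kernel_launch_overhead()
    else:
        print("CUDA not available; skipping GPU smoke test.")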