main_benchmark.py
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import time
from typing import Optional

import torch

from cvnets import get_model
from engine.utils import autocast_fn
from options.opts import get_benchmarking_arguments
from utils import logger
from utils.common_utils import device_setup
from utils.pytorch_to_coreml import convert_pytorch_to_coreml
from utils.tensor_utils import create_rand_tensor


def cpu_timestamp(*args, **kwargs):
    # perf_counter returns time in seconds
    return time.perf_counter()


def cuda_timestamp(cuda_sync=False, device=None, *args, **kwargs):
    if cuda_sync:
        torch.cuda.synchronize(device=device)
    # perf_counter returns time in seconds
    return time.perf_counter()
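
# Note: CUDA kernels execute asynchronously with respect to the host, so a
# host-side timestamp is only meaningful after torch.cuda.synchronize() has
# drained the device's work queue; that is why cuda_timestamp takes a
# cuda_sync flag.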


def step(
    time_fn,
    model,
    example_inputs,
    autocast_enable: bool = False,
    amp_precision: Optional[str] = "float16",
):
    # Only the closing timestamp synchronizes the device, so the measured
    # interval covers every kernel queued by the forward pass.
    start_time = time_fn()
    with autocast_fn(enabled=autocast_enable, amp_precision=amp_precision):
        model(example_inputs)
    end_time = time_fn(cuda_sync=True)
    return end_time - start_time
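
# Usage sketch: timing a single float16 forward pass on GPU, assuming `model`
# and `example_inp` already live on the CUDA device:
#
#   elapsed_s = step(
#       time_fn=cuda_timestamp,
#       model=model,
#       example_inputs=example_inp,
#       autocast_enable=True,
#       amp_precision="float16",
#   )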


def main_benchmark():
    # set-up
    opts = get_benchmarking_arguments()

    # device set-up
    opts = device_setup(opts)

    # sync variants of normalization layers need a distributed process group,
    # which this single-process benchmark does not set up, so fall back to the
    # local versions (e.g., sync_batch_norm -> batch_norm)
    norm_layer = getattr(opts, "model.normalization.name", "batch_norm")
    if norm_layer.find("sync") > -1:
        norm_layer = norm_layer.replace("sync_", "")
        setattr(opts, "model.normalization.name", norm_layer)

    device = getattr(opts, "dev.device", torch.device("cpu"))
    if torch.cuda.device_count() == 0:
        device = torch.device("cpu")

    time_fn = cpu_timestamp if device == torch.device("cpu") else cuda_timestamp

    warmup_iterations = getattr(opts, "benchmark.warmup_iter", 10)
    iterations = getattr(opts, "benchmark.n_iter", 50)
    batch_size = getattr(opts, "benchmark.batch_size", 1)

    mixed_precision = (
        False
        if device == torch.device("cpu")
        else getattr(opts, "common.mixed_precision", False)
    )
    mixed_precision_dtype = getattr(opts, "common.mixed_precision_dtype", "float16")
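    # Note: mixed precision is forced off for CPU runs above; when enabled,
    # the precision string ("float16" by default here) is forwarded to
    # autocast_fn unchanged.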

    # load the model
    model = get_model(opts)
    model.eval()

    # print model information
    model.info()

    example_inp = create_rand_tensor(opts=opts, device="cpu", batch_size=batch_size)

    # cool down for 5 seconds
    time.sleep(5)
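    # Presumably this idle period lets the machine settle after model
    # construction, so setup work does not skew the first timed iterations.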

    if getattr(opts, "benchmark.use_jit_model", False):
        # reuse the CoreML conversion utility to obtain a TorchScript (JIT)
        # version of the model; only the JIT artifact is requested here
        converted_models_dict = convert_pytorch_to_coreml(
            opts=None,
            pytorch_model=model,
            input_tensor=example_inp,
            jit_model_only=True,
        )
        model = converted_models_dict["jit"]

    model = model.to(device=device)
    example_inp = example_inp.to(device=device)
    model.eval()
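    # Warm-up rationale: the first iterations typically pay one-time costs
    # (CUDA context creation, cuDNN algorithm selection, allocator growth),
    # so they are run first and excluded from the measurement.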

    with torch.no_grad():
        # warm-up
        for i in range(warmup_iterations):
            step(
                time_fn=time_fn,
                model=model,
                example_inputs=example_inp,
                autocast_enable=mixed_precision,
                amp_precision=mixed_precision_dtype,
            )

        n_steps = n_samples = 0.0
        # run benchmark
        for i in range(iterations):
            step_time = step(
                time_fn=time_fn,
                model=model,
                example_inputs=example_inp,
                autocast_enable=mixed_precision,
                amp_precision=mixed_precision_dtype,
            )
            n_steps += step_time
            n_samples += batch_size

        logger.info(
            "Number of samples processed per second: {:.3f}".format(n_samples / n_steps)
        )
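        # Throughput is total samples divided by total forward time
        # (samples/second); with batch_size == 1, its reciprocal is the mean
        # per-image latency in seconds.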


if __name__ == "__main__":
    main_benchmark()
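
# Usage sketch (hypothetical flags, assuming the usual CVNets convention of
# exposing dotted option keys such as "benchmark.batch_size" as CLI switches):
#
#   python main_benchmark.py --common.config-file <path-to-model-config> \
#       --benchmark.batch-size 1 --benchmark.n-iter 100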