-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathonnx_export.py
executable file
·89 lines (72 loc) · 2.63 KB
/
onnx_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import argparse
import torch
from models import RetinaFace
from config import get_config
def parse_arguments():
    """Define and parse the command-line options of the ONNX export script.

    Returns:
        argparse.Namespace with two attributes:
            weights: path of the trained state_dict file to load
                (default ``'./weights/last.pth'``).
            network: backbone architecture name, restricted to the
                supported choices below (default ``'mobilenetv1'``).
    """
    # Backbones accepted by ``--network``; kept in the original order so the
    # generated ``--help`` text is unchanged.
    supported_backbones = [
        'mobilenetv1', 'mobilenetv1_0.25', 'mobilenetv1_0.50',
        'mobilenetv2', 'resnet50', 'resnet34', 'resnet18',
    ]

    cli = argparse.ArgumentParser(description='ONNX Export')
    cli.add_argument('-w', '--weights', type=str, default='./weights/last.pth',
                     help='Trained state_dict file path to open')
    cli.add_argument('-n', '--network', type=str, default='mobilenetv1',
                     choices=supported_backbones,
                     help='Backbone network architecture to use')

    parsed = cli.parse_args()
    return parsed
@torch.no_grad()
def onnx_export(params):
    """Load trained RetinaFace weights and export the model to an ONNX file.

    The ONNX file is written to the current working directory and is named
    after the weights file: ``<basename of params.weights without extension>.onnx``.

    Args:
        params: Namespace-like object (e.g. the result of ``parse_arguments``).
            Reads ``params.network`` (backbone name passed to ``get_config``)
            and ``params.weights`` (path of the saved state_dict).

    Raises:
        KeyError: If ``get_config`` returns ``None`` for ``params.network``.
    """
    # Get model configuration
    cfg = get_config(params.network)
    if cfg is None:
        raise KeyError(f"Config file for {params.network} not found!")

    # Prefer the GPU when one is available; export works on either device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = RetinaFace(cfg=cfg)
    model.to(device)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(params.weights, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully!")

    # Evaluation mode so the traced graph uses inference-time behavior.
    model.eval()

    # Output file name mirrors the weights file name, e.g. last.pth -> last.onnx.
    fname = os.path.splitext(os.path.basename(params.weights))[0]
    onnx_model = f'{fname}.onnx'
    print(f"==> Exporting model to ONNX format at '{onnx_model}'")

    # Dummy input used only for tracing: (batch_size=1, channels=3, height=640, width=640).
    # Allocated directly on the target device instead of created on CPU and copied.
    x = torch.randn(1, 3, 640, 640, device=device)

    # Input height/width are exported as dynamic, and the number of anchors
    # depends on them. So dim 1 of every output is declared dynamic as well,
    # not only the batch dim; otherwise its naming and dynamism are left to
    # exporter shape inference.
    # NOTE(review): assumes each output is laid out as (batch, num_anchors, ...)
    # — confirm against RetinaFace.forward.
    dynamic_axes = {
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'loc': {0: 'batch_size', 1: 'num_anchors'},        # Location output
        'conf': {0: 'batch_size', 1: 'num_anchors'},       # Confidence output
        'landmarks': {0: 'batch_size', 1: 'num_anchors'},  # Landmarks output
    }

    torch.onnx.export(
        model,                       # PyTorch model
        x,                           # Example input used for tracing
        onnx_model,                  # Output file path
        export_params=True,          # Store the trained weights inside the model file
        opset_version=11,            # ONNX opset to target
        do_constant_folding=True,    # Fold constant subgraphs at export time
        input_names=['input'],
        output_names=['loc', 'conf', 'landmarks'],
        dynamic_axes=dynamic_axes,
    )

    print(f"Model exported successfully to {onnx_model}")
if __name__ == '__main__':
    # Script entry point: read the CLI options, then export with those settings.
    args = parse_arguments()
    onnx_export(params=args)