-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgen_onnx.py
78 lines (69 loc) · 2.54 KB
/
gen_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import torch
from torch import nn as nn
from torch.nn import functional as F
# network of convolutions
class SRVGGNetCompact2(SRVGGNetCompact):
    """``SRVGGNetCompact`` variant whose forward pass ends at the upsampler.

    Only ``forward`` is overridden; construction and weights come from the
    parent class unchanged.
    """

    def forward(self, x):
        """Apply every layer of ``self.body`` in order, then ``self.upsampler``.

        No upscaled copy of ``x`` is added to the result here; in this file
        that composition step is performed separately by ``ModelEnd``.
        NOTE(review): presumably the parent ``forward`` performs that add
        itself -- confirm against ``SRVGGNetCompact``.
        """
        features = x
        # assumes self.body is an ordered module container (it is indexed
        # and len()'d by the original loop) -- iterate it directly
        for layer in self.body:
            features = layer(features)
        return self.upsampler(features)
# module for composing final image
class ModelEnd(torch.nn.Module):
    """Final compositing stage: add a nearest-neighbour upscaled input to ``out``.

    Args:
        upscale: Scale factor handed to ``F.interpolate`` when enlarging the
            input image (default 4).
    """

    def __init__(self, upscale=4):
        super().__init__()
        self.upscale = upscale

    def forward(self, x, out):
        """Return ``out`` plus ``x`` enlarged by ``self.upscale``.

        Args:
            x: Tensor to enlarge with nearest-neighbour interpolation.
                (assumes an N x C x H x W image batch -- TODO confirm)
            out: Tensor the enlarged ``x`` is added to; its shape must be
                compatible with the enlarged ``x``.

        Returns:
            A new tensor holding the sum. ``out`` itself is left unmodified.
        """
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        # Out-of-place add: ``out += base`` would overwrite the caller's
        # tensor as a hidden side effect of calling forward().
        return out + base
if __name__ == '__main__':
    # Script entry point: load the checkpoint and write two ONNX files,
    # "esrgan-small-pre.onnx" (convolution network) and
    # "esrgan-small-end.onnx" (final compositing module).
    model_path = "weights/realesr-animevideov3.pth"

    # Split esrgan network in two parts for webgl backend compatibility
    modelPre = SRVGGNetCompact2(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
    # NOTE(security): torch.load unpickles the checkpoint file; only load
    # weights obtained from a trusted source.
    loadnet = torch.load(model_path, map_location=torch.device('cpu'))
    # Use the 'params_ema' entry when the checkpoint has one, else 'params'.
    keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
    modelPre.load_state_dict(loadnet[keyname], strict=True)
    modelPre.eval()

    modelEnd = ModelEnd()

    # Export convolutions network
    dummy_input = torch.randn(1, 3, 300, 300)
    dynamic_axes = {
        "input": {2: "height", 3: "width"},
    }
    input_names = ["input"]
    output_names = ["output"]
    torch.onnx.export(
        modelPre,
        (dummy_input,),  # explicit one-element tuple of positional args
        "esrgan-small-pre.onnx",
        verbose=True,
        input_names=input_names,
        output_names=output_names,
        export_params=True,
        dynamic_axes=dynamic_axes,
        opset_version=12,
    )

    # Export network final module
    # The first network's output serves as the example second input.
    with torch.no_grad():
        dummy_input_pre = modelPre(dummy_input)
    print(dummy_input_pre.shape)
    # "input_pre" is the first network's output, whose spatial size differs
    # from "input" (ModelEnd enlarges "input" before adding the two), so its
    # dynamic axes need their own symbolic names: reusing "height"/"width"
    # would declare the two inputs' dimensions equal.
    dynamic_axes = {
        "input": {2: "height", 3: "width"},
        "input_pre": {2: "height_pre", 3: "width_pre"},
    }
    input_names = ["input", "input_pre"]
    output_names = ["output"]
    torch.onnx.export(
        modelEnd,
        (dummy_input, dummy_input_pre),
        "esrgan-small-end.onnx",
        verbose=True,
        input_names=input_names,
        output_names=output_names,
        export_params=True,
        dynamic_axes=dynamic_axes,
        opset_version=12,
    )