#
#Copyright (C) 2020-2021 ISTI-CNR
#Licensed under the BSD 3-Clause Clear License (see license.txt)
#
import os
import torch
import glob2
import re
import torch.nn as nn
from model_decoder_lstm import DecoderLSTM
from model_encoder_resnet import EncoderResNet
#
#
#
class ModelVideo(nn.Module):
    #
    #
    #
    def __init__(self, run, device, differential = 0):
        super(ModelVideo, self).__init__()
        #create the model: a ResNet-based CNN encoder followed by an LSTM decoder
        self.cnn_encoder = EncoderResNet(1024, 768, 0.2, 512, differential).to(device)
        self.lstm_decoder = DecoderLSTM().to(device)
        if run != '':
            print('Resume ModelVideo')
            try:
                ext = os.path.splitext(run)[1]
                if ext == '':
                    #run is a folder: resume from the latest checkpoint in <run>/ckpt
                    self.run = run
                    ckpt_dir = os.path.join(run, 'ckpt')
                    ckpt_name = os.path.join(ckpt_dir, '*.pth')
                    print(ckpt_dir)
                    print(ckpt_name)
                    ckpts = glob2.glob(ckpt_name)
                    assert ckpts, "No checkpoints to resume from!"

                    def get_epoch(ckpt_url):
                        #extract the epoch number from file names such as ckpt_e12.pth
                        s = re.findall(r"ckpt_e(\d+).pth", ckpt_url)
                        epoch = int(s[0]) if s else -1
                        return epoch, ckpt_url

                    print(ckpts)
                    start_epoch, ckpt = max(get_epoch(c) for c in ckpts)
                    print('Checkpoint:', ckpt)
                else:
                    #run is a direct path to a checkpoint file
                    ckpt = run
                bCuda = torch.cuda.is_available() #do we have a CUDA GPU?
                device = torch.device("cuda" if bCuda else "cpu")
                ckpt = torch.load(ckpt, map_location = device, weights_only = True)
                c0 = ckpt['cnn_model']
                try:
                    c1 = ckpt['lstm_model']
                except KeyError:
                    #older checkpoints store the decoder weights under 'rnn_model'
                    c1 = ckpt['rnn_model']
                self.cnn_encoder.load_state_dict(c0)
                self.lstm_decoder.load_state_dict(c1)
            except Exception as e:
                print('No model to resume:', e)
        self.cnn_encoder.eval()
        self.lstm_decoder.eval()
        self.device = device
    #
    #
    #
    def forward(self, x):
        #encode the frames with the CNN, then aggregate them over time with the LSTM
        output_cnn = self.cnn_encoder(x)
        output = self.lstm_decoder(output_cnn)
        return output
    #
    #
    #
    def predict(self, X):
        self.eval()
        X = X.to(self.device)
        if len(X.shape) == 4:
            #add the missing batch dimension for a single clip
            X = torch.unsqueeze(X, 0)
        return self.predictSimple(X)
    #
    #
    #
    def predictSimple(self, X):
        with torch.no_grad():
            output_cnn = self.cnn_encoder(X)
            output = self.lstm_decoder(output_cnn)
            return output.cpu().item()
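#
#
#
#A minimal usage sketch (not part of the original file): it shows how ModelVideo
#can be built and queried for a single prediction. The empty run string (start
#from random weights instead of resuming) and the input shape
#(frames, channels, height, width) = (16, 3, 224, 224) are assumptions; the real
#values depend on how EncoderResNet and DecoderLSTM are configured.
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ModelVideo(run = '', device = device)
    X = torch.rand(16, 3, 224, 224) #one clip: frames x channels x height x width
    score = model.predict(X) #predict() adds the batch dimension and returns a float
    print('Predicted score:', score)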