-
Notifications
You must be signed in to change notification settings - Fork 0
/
draft_run_training.py
161 lines (127 loc) · 6.84 KB
/
draft_run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import argparse
from utils.timers import Timer
from utils.loading_utils import load_model, get_device
from utils.event_readers import FixedSizeEventReader, FixedDurationEventReader
from utils.inference_utils import events_to_voxel_grid_pytorch
from options.inference_options import set_inference_options
from utils.ecoco_sequence_loader import *
from utils.train_utils import plot_training_data, pad_all, loss_fn, training_loop
import lpips
from utils.inference_utils import IntensityRescaler
from image_reconstructor import ImageReconstructor
# Number of events per window when reading fixed-size windows.
NUM_EVENTS_PER_WINDOW = 15119
# Window length (ms) used with --fixed_duration when no `window_duration`
# option has been registered on the parser; same value as the default of the
# disabled -T/--window_duration option kept for reference in main().
DEFAULT_WINDOW_DURATION_MS = 33.33


def read_sensor_size(path_to_events):
    """Return the sensor size stored on the first line of an event file.

    Only the first row of the file is parsed; it is expected to hold two
    whitespace-separated integers, ``width`` then ``height``.

    Args:
        path_to_events: path to the event text file.

    Returns:
        A ``(width, height)`` tuple of plain Python ints.
    """
    # NOTE(review): `pd` is not imported explicitly in the visible import
    # block; presumably it arrives through the wildcard import of
    # utils.ecoco_sequence_loader -- confirm.
    # `sep=r'\s+'` is the documented replacement for the deprecated
    # `delim_whitespace=True`, and the builtin `int` replaces the removed
    # `np.int` alias.
    header = pd.read_csv(path_to_events, sep=r'\s+', header=None,
                         names=['width', 'height'],
                         dtype={'width': int, 'height': int},
                         nrows=1)
    width, height = header.values[0]
    return int(width), int(height)


def main():
    """Build one event tensor and run the training loop on a pseudo-dataset.

    Steps: parse the command line, load the model, convert the first event
    window of the input file into a voxel grid, pad it, pair it with random
    labels, then call ``training_loop`` and plot the returned losses.

    Raises:
        RuntimeError: if the event reader yields no window at all.
    """
    # ------------------------------------------------------------------
    # Command line
    parser = argparse.ArgumentParser(
        description='Evaluating a trained network')
    parser.add_argument('-c', '--path_to_model', required=True, type=str,
                        help='path to model weights')
    parser.add_argument('-i', '--input_file', required=True, type=str)
    parser.add_argument('--fixed_duration', dest='fixed_duration', action='store_true')
    parser.set_defaults(fixed_duration=False)
    # Options disabled in this draft (kept for reference):
    # parser.add_argument('-N', '--window_size', default=None, type=int,
    #                     help="Size of each event window, in number of events. Ignored if --fixed_duration=True")
    # parser.add_argument('-T', '--window_duration', default=33.33, type=float,
    #                     help="Duration of each event window, in milliseconds. Ignored if --fixed_duration=False")
    # parser.add_argument('--num_events_per_pixel', default=0.35, type=float,
    #                     help='in case N (window size) is not specified, it will be \
    #                           automatically computed as N = width * height * num_events_per_pixel')
    parser.add_argument('--skipevents', default=0, type=int)
    parser.add_argument('--suboffset', default=0, type=int)
    # parser.add_argument('--compute_voxel_grid_on_cpu', dest='compute_voxel_grid_on_cpu', action='store_true')
    # parser.set_defaults(compute_voxel_grid_on_cpu=False)
    set_inference_options(parser)
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Sensor size and model
    path_to_events = args.input_file
    width, height = read_sensor_size(path_to_events)
    print('Sensor size: {} x {}'.format(width, height))

    device = get_device(args.use_gpu)
    model = load_model(args.path_to_model, map_location=device)
    model = model.to(device)
    # model.eval()
    #
    # reconstructor = ImageReconstructor(model, height, width, model.num_bins, args)

    # ------------------------------------------------------------------
    # Read the first chunk of events and turn it into an event tensor
    start_index = args.skipevents + args.suboffset

    if args.fixed_duration:
        # The -T/--window_duration option is disabled above, so fall back to
        # its old default unless set_inference_options() registered it.
        window_duration = getattr(args, 'window_duration', DEFAULT_WINDOW_DURATION_MS)
        event_window_iterator = FixedDurationEventReader(path_to_events,
                                                         duration_ms=window_duration,
                                                         start_index=start_index)
    else:
        event_window_iterator = FixedSizeEventReader(path_to_events,
                                                     num_events=NUM_EVENTS_PER_WINDOW,
                                                     start_index=start_index)

    event_tensor = None
    with Timer('Processing entire dataset'):
        for event_window in event_window_iterator:
            with Timer('Building event tensor'):
                event_tensor = events_to_voxel_grid_pytorch(event_window,
                                                            num_bins=model.num_bins,
                                                            width=width,
                                                            height=height,
                                                            device=device)
            # Only the first window is needed for the pseudo-dataset below.
            break

    if event_tensor is None:
        # Without this check the code below would fail with a bare NameError.
        raise RuntimeError(
            'No event window could be read from {!r} (start index {})'.format(
                path_to_events, start_index))

    # ------------------------------------------------------------------
    # Disabled draft of the training flow on real data (kept for reference):
    # device = get_device(True)
    # #DATA_DIR = '/home/richard/Q3/Deep_Learning/ruben-mr.github.io/data'
    #
    # batch_size = 2
    # sequence_length = 5
    # events = torch.tensor(full_event_tensor(range(batch_size), sequence_length, DATA_DIR)[0],
    #                       dtype=torch.float64).cuda().float()
    # images = torch.tensor(full_image_tensor(range(batch_size), sequence_length, DATA_DIR)[0],
    #                       dtype=torch.float64).cuda().float()
    #
    # # data pre-processing
    # events, images = pad_all(model, events, images)
    # train_loader = [(events, images)]
    # validation_loader = [(events, images)]
    #
    # if torch.cuda.is_available():
    #     reconstruction_loss_fn = lpips.LPIPS(net='vgg').cuda()
    # else:
    #     reconstruction_loss_fn = lpips.LPIPS(net='vgg')
    #
    # train_losses, val_losses = training_loop(model, loss_fn, train_loader, validation_loader,
    #                                          reconstruction_loss_fn, epoch=5)
    # plot_training_data(train_losses, val_losses)

    # ------------------------------------------------------------------
    # Pseudo-dataset built from the single event tensor.
    # NOTE(review): `torch` and `CropParameters` are not imported explicitly
    # in the visible import block; presumably they come from the wildcard
    # import of utils.ecoco_sequence_loader -- confirm.
    events = event_tensor.unsqueeze(dim=0)

    # Pre-processing: pad the tensor for the model.
    # Shape after padding per the original author's note: (1, 5, 184, 240).
    crop = CropParameters(width, height, model.num_encoders)
    events = crop.pad(events)

    # Add the leading sequence axis: (sequence_len, batch_size, channel, H, W).
    events = events.unsqueeze(dim=0)
    sequence_length = events.shape[0]
    batch_size = events.shape[1]
    events = events.tile((sequence_length, batch_size, 1, 1, 1))

    # Keep the data on the same device the model was moved to, instead of a
    # hard-coded 'cuda:0' that fails on CPU-only machines.
    events = events.to(device)

    # Random targets stand in for ground-truth images.
    labels = torch.rand(events.shape)
    labels = labels[:, :, 0:1, :, :]  # TODO: dealing with multiple channels
    labels = labels.to(device)

    train_loader = [(events, labels)]
    validation_loader = [(events.detach(), labels)]

    train_losses, val_losses = training_loop(model, loss_fn, train_loader, validation_loader)
    plot_training_data(train_losses, val_losses)


if __name__ == "__main__":
    main()