-
Notifications
You must be signed in to change notification settings - Fork 1
/
generator_old.py
60 lines (53 loc) · 2.31 KB
/
generator_old.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 19 15:03:23 2018
@author: sebastian
"""
import numpy as np
import torch
from torch.utils.data import Dataset
import h5py
class Generator(Dataset):
    """Stateful batch generator reading whole batches from an HDF5-like file.

    Each ``__getitem__`` call ignores its argument and fills preallocated
    buffers with ``batch_size`` samples taken from ``file['X']`` and
    ``file['Y']`` at positions ``idcs[idx], idcs[idx + 1], ...``. The cursor
    ``idx`` wraps around ``idcs`` and persists between calls, so consecutive
    calls walk through ``idcs`` cyclically.

    Parameters
    ----------
    file : mapping
        Object exposing ``'X'`` and ``'Y'`` array-likes with a ``.shape``
        (presumably an open ``h5py.File`` — confirm against callers).
        ``X`` is indexed as 5-D ``(sample, d1, d2, d3, d4)`` and ``Y`` as
        4-D ``(sample, h, w, c)``.
    batch_size : int
        Number of samples assembled per call.
    idcs : sequence of int
        Sample positions along axis 0 of ``X``/``Y``, visited cyclically.
    """

    def __init__(self, file, batch_size, idcs):
        self.X, self.Y = file['X'], file['Y']
        in_shape = self.X.shape[1:]
        out_shape = self.Y.shape[1:]

        self.batch_size = batch_size
        self.input_size = in_shape
        self.idcs = idcs
        self.idx = 0  # cursor into idcs; survives across __getitem__ calls

        # float64 staging buffers reused on every call; the tensors handed
        # back are fresh float32 copies (astype copies), not views of these.
        self.input_x = np.zeros((batch_size, *in_shape))
        self.input_y = np.zeros((batch_size, 1, *out_shape))

    def __len__(self):
        # Reports the batch size, not the number of available samples.
        return self.batch_size

    def __getitem__(self, batch=None):
        """Assemble and return the next ``(x, y)`` batch; ``batch`` is unused.

        Returns
        -------
        tuple of torch.Tensor
            ``x`` of shape ``(batch_size, d4, d1, d2, d3)`` and ``y`` of shape
            ``(batch_size, c, 1, h, w)``, both float32 and contiguous: the
            last axis of each buffer is moved to position 1.
        """
        wrap = len(self.idcs)
        for slot in range(self.batch_size):
            row = self.idcs[self.idx]
            self.input_x[slot, :, :, :, :] = self.X[row, :, :, :]
            self.input_y[slot, 0, :, :, :] = self.Y[row, :, :, :]
            self.idx = (self.idx + 1) % wrap

        last_axis_first = (0, 4, 1, 2, 3)
        x = torch.from_numpy(self.input_x.transpose(last_axis_first).astype(np.float32))
        y = torch.from_numpy(self.input_y.transpose(last_axis_first).astype(np.float32))
        return x.contiguous(), y.contiguous()
class Generator1(Dataset):
    """Stateful batch generator that keeps a window of frames per sample.

    Each ``__getitem__`` call ignores its argument and fills preallocated
    buffers with ``batch_size`` samples read from ``file['X']`` and
    ``file['Y']`` at positions ``idcs[idx], idcs[idx + 1], ...``. The cursor
    ``idx`` wraps around ``idcs`` and persists between calls.

    From axis 1 of ``X`` exactly ``n_frames`` consecutive entries are kept.
    The window starts ``n_frames // 2`` entries before ``center_frame``, so it
    is symmetric for odd ``n_frames``. For even ``n_frames`` it has one fewer
    entry after the centre than before it. Both returned arrays are divided
    by 255 (presumably 8-bit image data — confirm).

    Parameters
    ----------
    file : mapping
        Object exposing ``'X'`` and ``'Y'`` array-likes with a ``.shape``
        (presumably an open ``h5py.File`` — confirm against callers).
        ``X`` is indexed as 5-D ``(sample, frame, a, b, c)`` and ``Y`` as
        4-D ``(sample, a2, b2, c2)``.
    batch_size : int
        Number of samples assembled per call.
    idcs : sequence of int
        Sample positions along axis 0 of ``X``/``Y``, visited cyclically.
    n_frames : int, optional
        Number of frames kept from axis 1 of ``X`` (default 10).
    center_frame : int, optional
        Frame index the window is centred on. The default of 5 is the value
        that was previously hard-coded.

    Raises
    ------
    ValueError
        If the requested frame window does not fit inside axis 1 of ``X``.
    """

    def __init__(self, file, batch_size, idcs, n_frames=10, center_frame=5):
        self.n_frames = n_frames
        self.center_frame = center_frame
        self.X = file['X']
        self.Y = file['Y']
        input_size = self.X.shape[1:]
        output_size = self.Y.shape[1:]

        # Take exactly n_frames entries. The former end bound
        # 11 - (5 - n_frames // 2) yielded n_frames + 1 entries for even
        # n_frames (including the default 10), which cannot be stored in
        # input_x and made __getitem__ fail.
        self.frame_start = center_frame - n_frames // 2
        self.frame_stop = self.frame_start + n_frames
        available = input_size[0]
        if self.frame_start < 0 or self.frame_stop > available:
            # A negative start would silently wrap around to the end of the
            # frame axis; fail early with a readable message instead.
            raise ValueError(
                "frame window [%d, %d) for n_frames=%d, center_frame=%d does "
                "not fit in the %d frames of X"
                % (self.frame_start, self.frame_stop, n_frames,
                   center_frame, available)
            )

        # float64 staging buffers reused on every call; the returned tensors
        # are fresh float32 copies (astype copies), not views of these.
        self.input_x = np.zeros((batch_size, n_frames) + input_size[1:])
        self.input_y = np.zeros((batch_size, 1) + output_size)
        self.batch_size = batch_size
        self.input_size = input_size
        self.idcs = idcs
        self.idx = 0  # cursor into idcs; survives across __getitem__ calls

    def __len__(self):
        # Reports the batch size, not the number of available samples.
        return self.batch_size

    def __getitem__(self, batch=None):
        """Assemble and return the next ``(x, y)`` batch; ``batch`` is unused.

        Returns
        -------
        tuple of torch.Tensor
            ``x`` of shape ``(batch_size, c, n_frames, a, b)`` and ``y`` of
            shape ``(batch_size, 1, c2, a2, b2)``. Both are float32,
            contiguous and scaled by 1/255.
        """
        for slot in range(self.batch_size):
            sample = self.idcs[self.idx]
            self.input_x[slot, :, :, :, :] = self.X[
                sample, self.frame_start:self.frame_stop, :, :, :
            ]
            self.input_y[slot, 0, :, :, :] = self.Y[sample, :, :, :]
            self.idx = (self.idx + 1) % len(self.idcs)

        x = self.input_x.transpose([0, 4, 1, 2, 3]).astype(np.float32) / 255
        # NOTE(review): the target uses a different permutation from x. The
        # singleton axis stays at position 1 and the last axis moves to
        # position 2. Confirm this is the layout the model/loss expects.
        y = self.input_y.transpose((0, 1, 4, 2, 3)).astype(np.float32) / 255
        return (torch.from_numpy(x).contiguous(),
                torch.from_numpy(y).contiguous())