generator.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 19 15:03:23 2018
@author: sebastian
"""
import numpy as np
import torch
from torch.utils.data import Dataset
from skimage.transform import warp, AffineTransform

class Generator(Dataset):
    """Dataset that yields augmented frame stacks and half-resolution keypoint heatmaps.

    Each __getitem__ call assembles a full batch itself, which is why __len__
    returns the batch size.
    """

    def __init__(self, X, Y, batch_size, n_frames=5, size_x=None, size_y=None,
                 randomize=True, rotate=True, mirror=True):
        # X: video array of shape (n_clips, n_timesteps, height, width, 3)
        # Y: per-clip dict mapping each label name to a pair of coordinate lists
        # (rotate and mirror are accepted but not used below)
        self.X = X
        self.Y = Y
        n_categories = len(Y[0].keys())
        if not size_x:
            # Default to the spatial size of the input frames.
            [self.size_x, self.size_y] = X.shape[2:-1]
        else:
            [self.size_x, self.size_y] = [size_x, size_y]
        input_size = (3, n_frames) + (self.size_x, self.size_y)
        self.input_x = np.zeros((batch_size,) + input_size)
        # Targets are keypoint heatmaps at half the spatial resolution.
        self.input_y = np.zeros([batch_size, n_categories] + [self.size_x // 2, self.size_y // 2])
        self.batch_size = batch_size
        self.n_frames = n_frames
        # Indices available for sampling, optionally shuffled.
        self.idcs = np.arange(X.shape[0] - input_size[-1] - 1)
        if randomize:
            self.idcs = np.random.permutation(self.idcs)
        self.idx = 0

    def __len__(self):
        return self.batch_size

    def __getitem__(self, batch=None):
        self.input_y[:, :, :, :] = 0
        for sample in range(self.batch_size):
            n = self.idcs[self.idx]
            # Random affine augmentation: scale jitter plus a rotation by a
            # multiple of 90 degrees with a small amount of angular noise.
            scale = 0.8 + np.random.normal() / 10
            rotation = np.random.randint(4)
            # Translation keeps the rotated image inside the output frame.
            translation = [[0, 1, 1, 0][rotation] * self.X.shape[2] * scale,
                           [0, 0, 1, 1][rotation] * self.X.shape[3] * scale]
            # Random crop offsets were disabled in the original code:
            # translation[0] += np.random.randint(self.X.shape[2] - self.size_x)
            # translation[1] += np.random.randint(self.X.shape[3] - self.size_y)
            rotation *= np.pi / 2 + np.random.normal() / 50
            tform = AffineTransform(scale=[scale] * 2, rotation=rotation, translation=translation)
            start = 5 - (self.n_frames // 2)
            for img in range(self.n_frames):
                # Warp each frame with the sampled transform and move channels
                # first: (C, H, W).
                self.input_x[sample, :, img, :, :] = warp(
                    self.X[n, start + img, :, :, :] / 255, tform.inverse,
                    output_shape=(self.size_x, self.size_y)).transpose([2, 0, 1])
            for j, label in enumerate(self.Y[n].keys()):
                x_s = list(self.Y[n][label][0])
                y_s = list(self.Y[n][label][1])
                for k in range(len(x_s)):
                    # Apply the same affine transform to the keypoint and mark it
                    # in the half-resolution target heatmap if it stays in bounds.
                    coords = tform([x_s[k], y_s[k]]).astype(int)[0]
                    if all(coords >= 0) and all(coords < self.size_x):
                        self.input_y[sample, j, coords[1] // 2, coords[0] // 2] = 1
            self.idx = (self.idx + 1) % len(self.idcs)
        return (torch.from_numpy(self.input_x).float().cuda(),
                torch.from_numpy(self.input_y).float().cuda())
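

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). Assumed input
    # layout, inferred from the code above: X is a uint8 video array of shape
    # (n_clips, n_timesteps, height, width, 3) with n_clips large enough for
    # the index range computed in __init__, and Y gives, per clip, a dict
    # mapping each label name to a pair of coordinate lists
    # (x_positions, y_positions). All names below are illustrative only.
    X_demo = np.random.randint(0, 256, size=(70, 10, 64, 64, 3), dtype=np.uint8)
    Y_demo = [{"keypoint": ([10, 30], [12, 40])} for _ in range(70)]
    gen = Generator(X_demo, Y_demo, batch_size=4, n_frames=5)
    # __getitem__ moves both tensors to the GPU, so a CUDA device is required.
    if torch.cuda.is_available():
        batch_x, batch_y = gen[0]
        print(batch_x.shape)  # torch.Size([4, 3, 5, 64, 64])
        print(batch_y.shape)  # torch.Size([4, 1, 32, 32])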