Skip to content

Commit

Permalink
refactored dataloading for new dataset, improved augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiocarrara committed Dec 22, 2020
1 parent ddbbbb4 commit 46c4972
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 39 deletions.
176 changes: 140 additions & 36 deletions dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,149 @@
from scipy.ndimage import center_of_mass


def is_contained(ltrb, corners):
"""
A------B
| |
C------D
"""
A, B, C, D = corners

def is_in(X):
return (np.dot(C - A, X - A) > 0 and
np.dot(A - C, X - C) > 0 and
np.dot(B - A, X - A) > 0 and
np.dot(A - B, X - B) > 0)

l, t, r, b = ltrb
lt = np.array((l, t))
lb = np.array((l, b))
rt = np.array((r, t))
rb = np.array((r, b))

c1 = is_in(lt)
c2 = is_in(lb)
c3 = is_in(rt)
c4 = is_in(rb)

return c1 and c2 and c3 and c4

def load_xy(datum, x_shape, deterministic=True, no_pad=False):
x = load_img(datum.filename, color_mode='grayscale')
y = load_img(datum.target)
if datum[['pupil-y', 'pupil-x']].isnull().values.any():
pupil_map = img_to_array(y)[:, :, 0] # R channel = pupil

w, h = x.size

# find pupil center
def _get_pupil_position(pmap):
with np.errstate(invalid='raise'):
try:
pupil_pos_yx = center_of_mass(pupil_map)
return center_of_mass(pmap)
except:
pupil_pos_yx = (x_shape[0] / 2, x_shape[1] / 2)
else:
pupil_pos_yx = datum[['pupil-y', 'pupil-x']].values.astype(int)

# new pupil position: we want it to be in [0.1, 0.9] (in percentage height and width) 95% of times
pupil_new_pos = np.array([.5, .5]) if deterministic else np.random.normal(loc=0.5, scale=0.4, size=2)
# print('Center of mass not found, defaulting to image center:', datum.target)
return (x_shape[0] / 2, x_shape[1] / 2)

pupil_map = np.array(y)[:,:,0] # R channel = pupil
pupil_area = pupil_map.sum()

if deterministic:
pupil_new_pos = np.array([.5, .5])
s = 128
pupil_new_pos_yx = (pupil_new_pos * s).astype(int)
pupil_pos_yx = _get_pupil_position(pupil_map)
oy, ox = pupil_pos_yx - pupil_new_pos_yx
crop = (ox, oy, ox + s, oy + s) # Left, Upper; Right, Lower

if no_pad:
l, t, r, b = crop
dx = -l if l < 0 else w - r if r > w else 0
dy = -t if t < 0 else h - b if b > h else 0
crop = (l + dx, t + dy, r + dx, b + dy)
# the image may still be smaller than the crop area, adjusting ...
l, t, r, b = crop
crop = (max(0, l), max(0, t), min(w, r), min(h, b))

else: # random rotation, random pupil position, random scale

# pick random angle
angle = np.random.uniform(0, 90)
x = x.rotate(angle, expand=True)
y = y.rotate(angle, expand=True)

# find pupil in rotated image
pupil_map = np.array(y)[:, :, 0] # R channel = pupil
pupil_area = pupil_map.sum()
pupil_pos_yx = _get_pupil_position(pupil_map)


# find image corners in rotated image
theta = np.radians(angle)
cos_t, sin_t = np.cos(theta), np.sin(theta)
rot = np.array([[cos_t, sin_t], [-sin_t, cos_t]]) # build rotation for -theta (compensate flipped y-axis)
centered_corners = np.array([[-w / 2, -h / 2], [w / 2, -h / 2], [-w / 2, h / 2], [w / 2, h / 2]])
rotated_centered_corners = np.dot(centered_corners, rot.T)
rotated_corners = rotated_centered_corners - rotated_centered_corners.min(axis=0, keepdims=True)

# pick size of crop around the pupil for scale augmentation
# (this is constrained by the rotation angle and the original image size
min_s = 15
max_s = np.floor(min(w, h) / (sin_t + cos_t))
s = np.random.normal(loc=128, scale=40)
s = np.clip(s, min_s, max_s)

# print(angle, s)

A, B, C, D = rotated_corners

# find the feasibility region for the top-left corner of a square crop of size s
# the region is a rectangle MNOP
M = A + ((B - A) / w) * sin_t * s
N = B + ((A - B) / w) * cos_t * s
O = -s + D + ((C - D) / w) * sin_t * s
P = -s + C + ((D - C) / w) * cos_t * s
MNOP = np.stack((M,N,O,P))

# pick a new random position (in the crop space) in which to place the pupil center
pupil_new_pos_pct = np.random.normal(loc=0.5, scale=0.2, size=2)
pupil_new_pos_yx = (pupil_new_pos_pct * s).astype(int)
crop_top, crop_left = pupil_pos_yx - pupil_new_pos_yx
OC = np.array([crop_left, crop_top])

# ensure the crop origin is in the feasible region (MNOP)
## we do this in the feasibility region coordinate system (if xy in [0,1]^2, the crop is good):
## we first translate to M as new origin
MNOP_ = MNOP - M
OC_ = OC - M
M_,N_,O_,P_ = MNOP_

## the we use MN and MP as new basis
feasible2img = np.array([N_,P_]).T
img2feasible = np.linalg.inv(feasible2img)
OC_ = np.dot(img2feasible, OC_)

## we apply constraints in the new space and transform back
OC_ = np.clip(OC_, 0, 1)
crop_left, crop_top = np.dot(feasible2img, OC_) + M

crop = (crop_left, crop_top, crop_left + s, crop_top + s)

x = x.crop(crop)
y = y.crop(crop)

# size of crop around the pupil (for scale augmentation, needed?)
s = 128 if deterministic else int(np.random.normal(loc=128, scale=30))
# compute how much pupil is left in the image
new_pupil_map = np.array(y)[:, :, 0]
new_pupil_area = new_pupil_map.sum()

pupil_new_pos_yx = (pupil_new_pos * s).astype(int)
oy, ox = pupil_pos_yx - pupil_new_pos_yx
# print(new_pupil_area, pupil_area)
eye = (new_pupil_area / pupil_area) if pupil_area > 0 else 0

x = load_img(datum.filename, color_mode='grayscale')

crop = (ox, oy, ox + s, oy + s) # Left, Upper; Right, Lower
if no_pad:
w, h = x.size
l, t, r, b = crop
dx = -l if l < 0 else w - r if r > w else 0
dy = -t if t < 0 else h - b if b > h else 0
crop = (l + dx, t + dy, r + dx, b + dy)
# the image may still be smaller than the crop area, adjusting ...
l, t, r, b = crop
crop = (max(0, l), max(0, t), min(w, r), min(h, b))

x = x.crop(crop).resize(x_shape[:2]) # TODO: check interpolation type
y = y.crop(crop).resize(x_shape[:2])
if datum.eye & ~datum.blink:
datum.eye = eye

x = x.resize(x_shape[:2]) # TODO: check interpolation type
y = y.resize(x_shape[:2])

if not deterministic:
if not deterministic: # random flip
if np.random.rand() < .5:
x = x.transpose(Image.FLIP_LEFT_RIGHT)
y = y.transpose(Image.FLIP_LEFT_RIGHT)
Expand All @@ -56,15 +161,14 @@ def load_xy(datum, x_shape, deterministic=True, no_pad=False):
x = x.transpose(Image.FLIP_TOP_BOTTOM)
y = y.transpose(Image.FLIP_TOP_BOTTOM)

x = img_to_array(x) / 255.0
y = img_to_array(y)[:, :, :2] / 255.0 # keep only red and green channels

is_pupil_present = y[:, :, 0].sum() > 0
if datum.eye & ~datum.blink:
datum.eye = is_pupil_present
x = np.expand_dims(np.array(x), -1) / 255.0
y = np.array(y)[:, :, :2] / 255.0 # keep only red and green channels

y2 = datum[['eye', 'blink']]

# print('C = {}, crop = {}, s = {}, pnp = {}: '.format(rotated_corners, crop, s, pupil_new_pos), end='')
# print('E: {} B: {}'.format(datum.eye, datum.blink))

return pd.Series({'x': x, 'y': y, 'y2': y2})


Expand Down Expand Up @@ -101,7 +205,7 @@ def load_datasets(dataset_dirs):
def _load_and_prepare_annotations(dataset_dir):
data = os.path.join(dataset_dir, 'annotation', 'annotations.csv')
data = pd.read_csv(data)
data['target'] = dataset_dir + '/annotation/png/' + data.filename.str.replace(r'jpe?g', 'png')
data['target'] = dataset_dir + '/annotation/png/' + data.filename # .str.replace(r'jpe?g', 'png')
data['filename'] = dataset_dir + '/fullFrames/' + data.filename
return data

Expand Down
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(args):
x_shape = (args.resolution, args.resolution, 1)
y_shape = (args.resolution, args.resolution, 2)

train_gen = DataGen(train_data, x_shape=x_shape, batch_size=args.batch_size, no_pad=True)
train_gen = DataGen(train_data, x_shape=x_shape, batch_size=args.batch_size)
val_gen = DataGen(val_data, x_shape=x_shape, batch_size=args.batch_size, deterministic=True, no_pad=True)
test_gen = DataGen(test_data, x_shape=x_shape, batch_size=args.batch_size, deterministic=True, no_pad=True)

Expand Down Expand Up @@ -122,11 +122,12 @@ def main(args):


if __name__ == '__main__':
all_data = ['data/2p-dataset', 'data/H-dataset', 'data/NN_fullframe_extended', 'data/NN_mixed_dataset']
# default_data = ['data/2p-dataset', 'data/H-dataset', 'data/NN_fullframe_extended', 'data/NN_mixed_dataset']
default_data = ['data/NN_mixed_dataset_new']

parser = argparse.ArgumentParser(description='')
# data params
parser.add_argument('-d', '--data', nargs='+', default=all_data, help='Data directory (may be multiple)')
parser.add_argument('-d', '--data', nargs='+', default=default_data, help='Data directory (may be multiple)')
parser.add_argument('-r', '--resolution', type=int, default=128, help='Input image resolution')

# model params
Expand Down

0 comments on commit 46c4972

Please sign in to comment.