From 46c4972a086e7ace987bc9f2eded0f06d1b00809 Mon Sep 17 00:00:00 2001 From: Fabio Carrara Date: Tue, 22 Dec 2020 17:24:29 +0100 Subject: [PATCH] refactored dataloading for new dataset, improved augmentation --- dataloader.py | 176 +++++++++++++++++++++++++++++++++++++++----------- train.py | 7 +- 2 files changed, 144 insertions(+), 39 deletions(-) diff --git a/dataloader.py b/dataloader.py index 021f0ba..2b86598 100644 --- a/dataloader.py +++ b/dataloader.py @@ -10,44 +10,149 @@ from scipy.ndimage import center_of_mass +def is_contained(ltrb, corners): + """ + A------B + | | + C------D + """ + A, B, C, D = corners + + def is_in(X): + return (np.dot(C - A, X - A) > 0 and + np.dot(A - C, X - C) > 0 and + np.dot(B - A, X - A) > 0 and + np.dot(A - B, X - B) > 0) + + l, t, r, b = ltrb + lt = np.array((l, t)) + lb = np.array((l, b)) + rt = np.array((r, t)) + rb = np.array((r, b)) + + c1 = is_in(lt) + c2 = is_in(lb) + c3 = is_in(rt) + c4 = is_in(rb) + + return c1 and c2 and c3 and c4 + def load_xy(datum, x_shape, deterministic=True, no_pad=False): + x = load_img(datum.filename, color_mode='grayscale') y = load_img(datum.target) - if datum[['pupil-y', 'pupil-x']].isnull().values.any(): - pupil_map = img_to_array(y)[:, :, 0] # R channel = pupil + + w, h = x.size + + # find pupil center + def _get_pupil_position(pmap): with np.errstate(invalid='raise'): try: - pupil_pos_yx = center_of_mass(pupil_map) + return center_of_mass(pmap) except: - pupil_pos_yx = (x_shape[0] / 2, x_shape[1] / 2) - else: - pupil_pos_yx = datum[['pupil-y', 'pupil-x']].values.astype(int) - - # new pupil position: we want it to be in [0.1, 0.9] (in percentage height and width) 95% of times - pupil_new_pos = np.array([.5, .5]) if deterministic else np.random.normal(loc=0.5, scale=0.4, size=2) + # print('Center of mass not found, defaulting to image center:', datum.target) + return (x_shape[0] / 2, x_shape[1] / 2) + + pupil_map = np.array(y)[:,:,0] # R channel = pupil + pupil_area = pupil_map.sum() + + if deterministic: + pupil_new_pos = np.array([.5, .5]) + s = 128 + pupil_new_pos_yx = (pupil_new_pos * s).astype(int) + pupil_pos_yx = _get_pupil_position(pupil_map) + oy, ox = pupil_pos_yx - pupil_new_pos_yx + crop = (ox, oy, ox + s, oy + s) # Left, Upper; Right, Lower + + if no_pad: + l, t, r, b = crop + dx = -l if l < 0 else w - r if r > w else 0 + dy = -t if t < 0 else h - b if b > h else 0 + crop = (l + dx, t + dy, r + dx, b + dy) + # the image may still be smaller than the crop area, adjusting ... + l, t, r, b = crop + crop = (max(0, l), max(0, t), min(w, r), min(h, b)) + + else: # random rotation, random pupil position, random scale + + # pick random angle + angle = np.random.uniform(0, 90) + x = x.rotate(angle, expand=True) + y = y.rotate(angle, expand=True) + + # find pupil in rotated image + pupil_map = np.array(y)[:, :, 0] # R channel = pupil + pupil_area = pupil_map.sum() + pupil_pos_yx = _get_pupil_position(pupil_map) + + + # find image corners in rotated image + theta = np.radians(angle) + cos_t, sin_t = np.cos(theta), np.sin(theta) + rot = np.array([[cos_t, sin_t], [-sin_t, cos_t]]) # build rotation for -theta (compensate flipped y-axis) + centered_corners = np.array([[-w / 2, -h / 2], [w / 2, -h / 2], [-w / 2, h / 2], [w / 2, h / 2]]) + rotated_centered_corners = np.dot(centered_corners, rot.T) + rotated_corners = rotated_centered_corners - rotated_centered_corners.min(axis=0, keepdims=True) + + # pick size of crop around the pupil for scale augmentation + # (this is constrained by the rotation angle and the original image size + min_s = 15 + max_s = np.floor(min(w, h) / (sin_t + cos_t)) + s = np.random.normal(loc=128, scale=40) + s = np.clip(s, min_s, max_s) + + # print(angle, s) + + A, B, C, D = rotated_corners + + # find the feasibility region for the top-left corner of a square crop of size s + # the region is a rectangle MNOP + M = A + ((B - A) / w) * sin_t * s + N = B + ((A - B) / w) * cos_t * s + O = -s + D + ((C - D) / w) * sin_t * s + P = -s + C + ((D - C) / w) * cos_t * s + MNOP = np.stack((M,N,O,P)) + + # pick a new random position (in the crop space) in which to place the pupil center + pupil_new_pos_pct = np.random.normal(loc=0.5, scale=0.2, size=2) + pupil_new_pos_yx = (pupil_new_pos_pct * s).astype(int) + crop_top, crop_left = pupil_pos_yx - pupil_new_pos_yx + OC = np.array([crop_left, crop_top]) + + # ensure the crop origin is in the feasible region (MNOP) + ## we do this in the feasibility region coordinate system (if xy in [0,1]^2, the crop is good): + ## we first translate to M as new origin + MNOP_ = MNOP - M + OC_ = OC - M + M_,N_,O_,P_ = MNOP_ + + ## the we use MN and MP as new basis + feasible2img = np.array([N_,P_]).T + img2feasible = np.linalg.inv(feasible2img) + OC_ = np.dot(img2feasible, OC_) + + ## we apply constraints in the new space and transform back + OC_ = np.clip(OC_, 0, 1) + crop_left, crop_top = np.dot(feasible2img, OC_) + M + + crop = (crop_left, crop_top, crop_left + s, crop_top + s) + + x = x.crop(crop) + y = y.crop(crop) - # size of crop around the pupil (for scale augmentation, needed?) - s = 128 if deterministic else int(np.random.normal(loc=128, scale=30)) + # compute how much pupil is left in the image + new_pupil_map = np.array(y)[:, :, 0] + new_pupil_area = new_pupil_map.sum() - pupil_new_pos_yx = (pupil_new_pos * s).astype(int) - oy, ox = pupil_pos_yx - pupil_new_pos_yx + # print(new_pupil_area, pupil_area) + eye = (new_pupil_area / pupil_area) if pupil_area > 0 else 0 - x = load_img(datum.filename, color_mode='grayscale') - - crop = (ox, oy, ox + s, oy + s) # Left, Upper; Right, Lower - if no_pad: - w, h = x.size - l, t, r, b = crop - dx = -l if l < 0 else w - r if r > w else 0 - dy = -t if t < 0 else h - b if b > h else 0 - crop = (l + dx, t + dy, r + dx, b + dy) - # the image may still be smaller than the crop area, adjusting ... - l, t, r, b = crop - crop = (max(0, l), max(0, t), min(w, r), min(h, b)) - - x = x.crop(crop).resize(x_shape[:2]) # TODO: check interpolation type - y = y.crop(crop).resize(x_shape[:2]) + if datum.eye & ~datum.blink: + datum.eye = eye + + x = x.resize(x_shape[:2]) # TODO: check interpolation type + y = y.resize(x_shape[:2]) - if not deterministic: + if not deterministic: # random flip if np.random.rand() < .5: x = x.transpose(Image.FLIP_LEFT_RIGHT) y = y.transpose(Image.FLIP_LEFT_RIGHT) @@ -56,15 +161,14 @@ def load_xy(datum, x_shape, deterministic=True, no_pad=False): x = x.transpose(Image.FLIP_TOP_BOTTOM) y = y.transpose(Image.FLIP_TOP_BOTTOM) - x = img_to_array(x) / 255.0 - y = img_to_array(y)[:, :, :2] / 255.0 # keep only red and green channels - - is_pupil_present = y[:, :, 0].sum() > 0 - if datum.eye & ~datum.blink: - datum.eye = is_pupil_present + x = np.expand_dims(np.array(x), -1) / 255.0 + y = np.array(y)[:, :, :2] / 255.0 # keep only red and green channels y2 = datum[['eye', 'blink']] + # print('C = {}, crop = {}, s = {}, pnp = {}: '.format(rotated_corners, crop, s, pupil_new_pos), end='') + # print('E: {} B: {}'.format(datum.eye, datum.blink)) + return pd.Series({'x': x, 'y': y, 'y2': y2}) @@ -101,7 +205,7 @@ def load_datasets(dataset_dirs): def _load_and_prepare_annotations(dataset_dir): data = os.path.join(dataset_dir, 'annotation', 'annotations.csv') data = pd.read_csv(data) - data['target'] = dataset_dir + '/annotation/png/' + data.filename.str.replace(r'jpe?g', 'png') + data['target'] = dataset_dir + '/annotation/png/' + data.filename # .str.replace(r'jpe?g', 'png') data['filename'] = dataset_dir + '/fullFrames/' + data.filename return data diff --git a/train.py b/train.py index 2750a28..5047e1a 100644 --- a/train.py +++ b/train.py @@ -37,7 +37,7 @@ def main(args): x_shape = (args.resolution, args.resolution, 1) y_shape = (args.resolution, args.resolution, 2) - train_gen = DataGen(train_data, x_shape=x_shape, batch_size=args.batch_size, no_pad=True) + train_gen = DataGen(train_data, x_shape=x_shape, batch_size=args.batch_size) val_gen = DataGen(val_data, x_shape=x_shape, batch_size=args.batch_size, deterministic=True, no_pad=True) test_gen = DataGen(test_data, x_shape=x_shape, batch_size=args.batch_size, deterministic=True, no_pad=True) @@ -122,11 +122,12 @@ def main(args): if __name__ == '__main__': - all_data = ['data/2p-dataset', 'data/H-dataset', 'data/NN_fullframe_extended', 'data/NN_mixed_dataset'] + # default_data = ['data/2p-dataset', 'data/H-dataset', 'data/NN_fullframe_extended', 'data/NN_mixed_dataset'] + default_data = ['data/NN_mixed_dataset_new'] parser = argparse.ArgumentParser(description='') # data params - parser.add_argument('-d', '--data', nargs='+', default=all_data, help='Data directory (may be multiple)') + parser.add_argument('-d', '--data', nargs='+', default=default_data, help='Data directory (may be multiple)') parser.add_argument('-r', '--resolution', type=int, default=128, help='Input image resolution') # model params