diff --git a/README.md b/README.md index b259d52..dce6408 100644 --- a/README.md +++ b/README.md @@ -29,42 +29,70 @@ see details in `pyoroject.toml` from tf_raft.model import RAFT, SmallRAFT from tf_raft.losses import sequence_loss, end_point_error -# iters means number of recurrent update of flow -raft = RAFT(iters=iters) +# iters/iters_pred are the number of recurrent update of flow in training/prediction +raft = RAFT(iters=iters, iters_pred=iters_pred) raft.compile( optimizer=optimizer, + clip_norm=clip_norm, loss=sequence_loss, epe=end_point_error ) raft.fit( - dataset, + ds_train, epochs=epochs, callbacks=callbacks, + steps_per_epoch=train_size//batch_size, + validation_data=ds_val, + validation_steps=val_size ) ``` -In practice, you are required to prepare dataset, optimizer, callbacks etc, check details in `train.py`. +In practice, you are required to prepare dataset, optimizer, callbacks etc, check details in `train_sintel.py` or `train_chairs.py`. -## Load the pretrained weights +### Train via YAML configuration -You can download the pretrained weights via `gsutil` or `curl` (trained on MPI-Sintel Clean, and FlyingChairs) +`train_chairs.py` and `train_sintel.py` train RAFT model via YAML configuration. Sample configs are in `configs` directory. Run; ``` shell -$ gsutil cp -r gs://tf-raft-pretrained/checkpoints . +$ python train_chairs.py /path/to/config.yml +``` + +## Pre-trained models + +I made the pre-trained weights (on both FlyingChairs and MPI-Sintel) public. +You can download them via `gsutil` or `curl`. + +### Trained weights on FlyingChairs + +``` shell +$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T18-38/checkpoints . 
+``` +or +``` shell +$ mkdir checkpoints +$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.data-00000-of-00001 +$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.index +$ mv model* checkpoints/ +``` + +### Trained weights on MPI-Sintel (Clean path) + +``` shell +$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T08-51/checkpoints . ``` or ``` shell $ mkdir checkpoints -$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/checkpoints/model.data-00000-of-00001 -$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/checkpoints/model.index +$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.data-00000-of-00001 +$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.index $ mv model* checkpoints/ ``` -then +### Load weights ``` python -raft = RAFT(iters=iters) +raft = RAFT(iters=iters, iters_pred=iters_pred) raft.load_weights('checkpoints/model') # forward (with dummy inputs) @@ -76,11 +104,13 @@ print(flow_predictions[-1].shape) # >> (1, 448, 512, 2) ``` ## Note -Though I have tried to reproduce the original implementation faithfully, there is some difference between it and my implementation (mainly because of used framework: PyTorch/TensorFlow); +Though I have tried to reproduce the original implementation faithfully, there is some difference between the original one and mine (mainly because of the frameworks used: PyTorch/TensorFlow); + +- The original implementation provides a cuda-based correlation function but I don't. My TF-based implementation works well, but the cuda-based one may run faster. +- I have trained my model on FlyingChairs and MPI-Sintel separately in my private environment (GCP with P100 accelerator). The model has been trained well, but not reached the best score reported in the paper (trained on multiple datasets). +- The original one uses mixed-precision. 
This may make training much faster, but I don't use it. TensorFlow also enables mixed-precision with a few additional lines, see https://www.tensorflow.org/guide/mixed_precision if interested. -- The original implements cuda-based correlation function but I don't. My TF-based implementation works well, but cuda-based one may runs faster. -- I have trained my model only on MPI-Sintel dataset in my private environment (GCP with P100 accelerator). The model has been trained well, but not reached the best score reported in the paper (trained on multiple datasets). -- The original uses mixed-precision. This may get traininig much faster, but I don't. TensorFlow also enables mixed-precision with few additional lines, see https://www.tensorflow.org/guide/mixed_precision if interested. +Additionally, global gradient clipping seems to be essential for stable training though it is not emphasized in the original paper. This operation can be done via `torch.nn.utils.clip_grad_norm_(model.parameters(), clip)` in PyTorch, `tf.clip_by_global_norm(grads, clip_norm)` in TF (implemented in `self.train_step` in `tf_raft/model.py`). 
## References - https://github.com/princeton-vl/RAFT diff --git a/configs/train_chairs.yml b/configs/train_chairs.yml new file mode 100644 index 0000000..9b890dc --- /dev/null +++ b/configs/train_chairs.yml @@ -0,0 +1,27 @@ +data: + root: './FlyingChairs_release/data' + split_txt: './FlyingChairs_release/FlyingChairs_train_val.txt' + +augment: + crop_size: [368, 496] + min_scale: -0.1 + max_scale: 1.0 + do_flip: True + +model: + iters: 12 + iters_pred: 24 + resume: 0 # '/path/to/checkpoint' + +train: + epochs: 1 + batch_size: 4 + learning_rate: 0.0004 + weight_decay: 0.0001 + clip_norm: 1 + +visualize: + num_visualize: 4 + choose_random: True + +logdir: './logs' diff --git a/configs/train_sintel.yml b/configs/train_sintel.yml new file mode 100644 index 0000000..92c0721 --- /dev/null +++ b/configs/train_sintel.yml @@ -0,0 +1,26 @@ +data: + root: './MPI-Sintel-complete' + +augment: + crop_size: [368, 768] + min_scale: -0.2 + max_scale: 0.6 + do_flip: True + +model: + iters: 12 + iters_pred: 24 + resume: 0 # '/path/to/checkpoint' + +train: + epochs: 1 + batch_size: 4 + learning_rate: 0.00012 + weight_decay: 0.00001 + clip_norm: 1 + +visualize: + num_visualize: 4 + choose_random: True + +logdir: './logs' diff --git a/predicted_flows/epoch001_bamboo_3_17.png b/predicted_flows/epoch001_bamboo_3_17.png deleted file mode 100644 index 29f815c..0000000 Binary files a/predicted_flows/epoch001_bamboo_3_17.png and /dev/null differ diff --git a/predicted_flows/epoch001_cave_3_37.png b/predicted_flows/epoch001_cave_3_37.png deleted file mode 100644 index 4933da8..0000000 Binary files a/predicted_flows/epoch001_cave_3_37.png and /dev/null differ diff --git a/predicted_flows/epoch001_market_1_42.png b/predicted_flows/epoch001_market_1_42.png deleted file mode 100644 index 5ff32f2..0000000 Binary files a/predicted_flows/epoch001_market_1_42.png and /dev/null differ diff --git a/predicted_flows/epoch001_temple_1_0.png b/predicted_flows/epoch001_temple_1_0.png deleted file mode 100644 
index 96538de..0000000 Binary files a/predicted_flows/epoch001_temple_1_0.png and /dev/null differ diff --git a/predicted_flows/epoch002_PERTURBED_market_3_26.png b/predicted_flows/epoch002_PERTURBED_market_3_26.png deleted file mode 100644 index 470dbe5..0000000 Binary files a/predicted_flows/epoch002_PERTURBED_market_3_26.png and /dev/null differ diff --git a/predicted_flows/epoch002_PERTURBED_shaman_1_2.png b/predicted_flows/epoch002_PERTURBED_shaman_1_2.png deleted file mode 100644 index 2eb87c4..0000000 Binary files a/predicted_flows/epoch002_PERTURBED_shaman_1_2.png and /dev/null differ diff --git a/predicted_flows/epoch002_mountain_2_10.png b/predicted_flows/epoch002_mountain_2_10.png deleted file mode 100644 index 52b569f..0000000 Binary files a/predicted_flows/epoch002_mountain_2_10.png and /dev/null differ diff --git a/predicted_flows/epoch002_wall_18.png b/predicted_flows/epoch002_wall_18.png deleted file mode 100644 index e6864d0..0000000 Binary files a/predicted_flows/epoch002_wall_18.png and /dev/null differ diff --git a/predicted_flows/epoch003_PERTURBED_market_3_24.png b/predicted_flows/epoch003_PERTURBED_market_3_24.png deleted file mode 100644 index b25e2a7..0000000 Binary files a/predicted_flows/epoch003_PERTURBED_market_3_24.png and /dev/null differ diff --git a/predicted_flows/epoch003_bamboo_3_25.png b/predicted_flows/epoch003_bamboo_3_25.png deleted file mode 100644 index fbb6c74..0000000 Binary files a/predicted_flows/epoch003_bamboo_3_25.png and /dev/null differ diff --git a/predicted_flows/epoch003_bamboo_3_31.png b/predicted_flows/epoch003_bamboo_3_31.png deleted file mode 100644 index 3690c00..0000000 Binary files a/predicted_flows/epoch003_bamboo_3_31.png and /dev/null differ diff --git a/predicted_flows/epoch003_market_1_11.png b/predicted_flows/epoch003_market_1_11.png deleted file mode 100644 index 7b266ee..0000000 Binary files a/predicted_flows/epoch003_market_1_11.png and /dev/null differ diff --git 
a/predicted_flows/epoch004_PERTURBED_market_3_7.png b/predicted_flows/epoch004_PERTURBED_market_3_7.png deleted file mode 100644 index 4b2a973..0000000 Binary files a/predicted_flows/epoch004_PERTURBED_market_3_7.png and /dev/null differ diff --git a/predicted_flows/epoch004_cave_3_12.png b/predicted_flows/epoch004_cave_3_12.png deleted file mode 100644 index 71523d0..0000000 Binary files a/predicted_flows/epoch004_cave_3_12.png and /dev/null differ diff --git a/predicted_flows/epoch004_market_4_25.png b/predicted_flows/epoch004_market_4_25.png deleted file mode 100644 index be7184a..0000000 Binary files a/predicted_flows/epoch004_market_4_25.png and /dev/null differ diff --git a/predicted_flows/epoch004_wall_30.png b/predicted_flows/epoch004_wall_30.png deleted file mode 100644 index 6ebf7fc..0000000 Binary files a/predicted_flows/epoch004_wall_30.png and /dev/null differ diff --git a/predicted_flows/epoch005_bamboo_3_36.png b/predicted_flows/epoch005_bamboo_3_36.png deleted file mode 100644 index 9e0cd8f..0000000 Binary files a/predicted_flows/epoch005_bamboo_3_36.png and /dev/null differ diff --git a/predicted_flows/epoch005_market_1_31.png b/predicted_flows/epoch005_market_1_31.png deleted file mode 100644 index 2909869..0000000 Binary files a/predicted_flows/epoch005_market_1_31.png and /dev/null differ diff --git a/predicted_flows/epoch005_market_4_45.png b/predicted_flows/epoch005_market_4_45.png deleted file mode 100644 index 2994ae2..0000000 Binary files a/predicted_flows/epoch005_market_4_45.png and /dev/null differ diff --git a/predicted_flows/epoch005_mountain_2_2.png b/predicted_flows/epoch005_mountain_2_2.png deleted file mode 100644 index 9509254..0000000 Binary files a/predicted_flows/epoch005_mountain_2_2.png and /dev/null differ diff --git a/samples_sintel/epoch030_204.png b/samples_sintel/epoch030_204.png new file mode 100644 index 0000000..bf33d59 Binary files /dev/null and b/samples_sintel/epoch030_204.png differ diff --git 
a/samples_sintel/epoch030_362.png b/samples_sintel/epoch030_362.png new file mode 100644 index 0000000..acbcd6f Binary files /dev/null and b/samples_sintel/epoch030_362.png differ diff --git a/samples_sintel/epoch030_899.png b/samples_sintel/epoch030_899.png new file mode 100644 index 0000000..02d4b10 Binary files /dev/null and b/samples_sintel/epoch030_899.png differ diff --git a/samples_sintel/epoch030_988.png b/samples_sintel/epoch030_988.png new file mode 100644 index 0000000..00c5926 Binary files /dev/null and b/samples_sintel/epoch030_988.png differ diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py index 8d927b6..0abc2fa 100644 --- a/tests/losses/test_losses.py +++ b/tests/losses/test_losses.py @@ -58,7 +58,7 @@ def test_end_point_error(data): u3_valid = 3 / 8 u5_valid = 5 / 8 - info = end_point_error([flow_gt, valid], predictions) + info = end_point_error([flow_gt, valid], predictions[-1]) # 2 decimal for sqrt precision np.testing.assert_almost_equal(info['epe'], epe_valid, decimal=2) diff --git a/tests/test_model.py b/tests/test_model.py index 9c44d60..7b61eca 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -43,33 +43,35 @@ def test_upsample_flow(): def test_raft(): iters = 6 + iters_pred = 12 image1 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3)) image2 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3)) - model = RAFT(drop_rate=0.0, iters=iters) + model = RAFT(drop_rate=0.0, iters=iters, iters_pred=iters_pred) output = model([image1, image2], training=True) assert len(output) == iters for flow in output: assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2) output = model([image1, image2], training=False) - assert len(output) == iters + assert len(output) == iters_pred for flow in output: assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2) def test_small_raft(): iters = 6 + iters_pred = 12 image1 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3)) image2 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3)) - model = 
SmallRAFT(drop_rate=0.0, iters=iters) + model = SmallRAFT(drop_rate=0.0, iters=iters, iters_pred=iters_pred) output = model([image1, image2], training=True) assert len(output) == iters for flow in output: assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2) output = model([image1, image2], training=False) - assert len(output) == iters + assert len(output) == iters_pred for flow in output: assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2) diff --git a/tf_raft/datasets/augmentor.py b/tf_raft/datasets/augmentor.py index 493a6bc..df0ebbf 100644 --- a/tf_raft/datasets/augmentor.py +++ b/tf_raft/datasets/augmentor.py @@ -1,16 +1,9 @@ import numpy as np -import random -import math -from PIL import Image - import cv2 cv2.setNumThreads(0) cv2.ocl.setUseOpenCL(False) - -# import torch -# from torchvision.transforms import ColorJitter -# import torch.nn.functional as F import albumentations as A +from PIL import Image class FlowAugmentor: @@ -37,7 +30,7 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): contrast_limit=0.4 ), A.HueSaturationValue( - hue_shift_limit=int(0.5/3.14*255), + hue_shift_limit=int(0.5/3.14*180), sat_shift_limit=int(0.4*255), val_shift_limit=int(0.) ) @@ -135,6 +128,7 @@ def __call__(self, img1, img2, flow): return img1, img2, flow + class SparseFlowAugmentor: def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): # spatial augmentation params @@ -151,13 +145,25 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): self.v_flip_prob = 0.1 # photometric augmentation params - self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + # self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.photo_aug = A.Compose([ + A.RandomBrightnessContrast( + brightness_limit=0.3, + contrast_limit=0.3 + ), + A.HueSaturationValue( + hue_shift_limit=int(0.3/3.14*180), + sat_shift_limit=int(0.3*255), + val_shift_limit=int(0.) 
+ ) + ]) self.asymmetric_color_aug_prob = 0.2 self.eraser_aug_prob = 0.5 def color_transform(self, img1, img2): image_stack = np.concatenate([img1, img2], axis=0) - image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + # image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + image_stack = self.photo_aug(image=image_stack)['image'] img1, img2 = np.split(image_stack, 2, axis=0) return img1, img2 @@ -248,7 +254,6 @@ def spatial_transform(self, img1, img2, flow, valid): valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] return img1, img2, flow, valid - def __call__(self, img1, img2, flow, valid): img1, img2 = self.color_transform(img1, img2) img1, img2 = self.eraser_transform(img1, img2) diff --git a/tf_raft/losses/losses.py b/tf_raft/losses/losses.py index 6924954..83eca7d 100644 --- a/tf_raft/losses/losses.py +++ b/tf_raft/losses/losses.py @@ -28,7 +28,7 @@ def end_point_error(y_true, y_pred, max_flow=400): mag = tf.sqrt(tf.reduce_sum(flow_gt**2, axis=-1)) valid = valid & (mag < max_flow) - epe = tf.sqrt(tf.reduce_sum((y_pred[-1] - flow_gt)**2, axis=-1)) + epe = tf.sqrt(tf.reduce_sum((y_pred - flow_gt)**2, axis=-1)) epe = epe[valid] epe_under1 = tf.cast(epe < 1, dtype=tf.float32) epe_under3 = tf.cast(epe < 3, dtype=tf.float32) diff --git a/tf_raft/model.py b/tf_raft/model.py index ebc7cc4..40d0a0e 100644 --- a/tf_raft/model.py +++ b/tf_raft/model.py @@ -8,7 +8,7 @@ class RAFT(tf.keras.Model): - def __init__(self, drop_rate=0, iters=12, **kwargs): + def __init__(self, drop_rate=0, iters=12, iters_pred=24, **kwargs): super().__init__(**kwargs) self.hidden_dim = hdim = 128 @@ -19,6 +19,7 @@ def __init__(self, drop_rate=0, iters=12, **kwargs): self.drop_rate = drop_rate self.iters = iters + self.iters_pred = iters_pred self.fnet = BasicEncoder(output_dim=256, norm_type='instance', @@ -88,7 +89,8 @@ def call(self, inputs, training): coords0, coords1 = self.initialize_flow(image1) 
flow_predictions = [] - for i in range(self.iters): + iters = self.iters if training else self.iters_pred + for i in range(iters): # (bs, h, w, 81xnum_levels) corr = correlation.retrieve(coords1) @@ -106,9 +108,10 @@ def call(self, inputs, training): # flow_predictions[-1] is the finest output return flow_predictions - def compile(self, optimizer, loss, epe, **kwargs): + def compile(self, optimizer, clip_norm, loss, epe, **kwargs): super().compile(**kwargs) self.optimizer = optimizer + self.clip_norm = clip_norm self.loss = loss self.epe = epe @@ -129,9 +132,10 @@ def train_step(self, data): flow_predictions = self([image1, image2], training=True) loss = self.loss([flow, valid], flow_predictions) grads = tape.gradient(loss, self.trainable_weights) + grads, _ = tf.clip_by_global_norm(grads, self.clip_norm) self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) - info = self.epe([flow, valid], flow_predictions) + info = self.epe([flow, valid], flow_predictions[-1]) self.flow_metrics['loss'].update_state(loss) self.flow_metrics['epe'].update_state(info['epe']) self.flow_metrics['u1'].update_state(info['u1']) @@ -145,10 +149,8 @@ def test_step(self, data): image2 = tf.cast(image2, dtype=tf.float32) flow_predictions = self([image1, image2], training=False) - loss = self.loss([flow, valid], flow_predictions) - info = self.epe([flow, valid], flow_predictions) - self.flow_metrics['loss'].update_state(loss) + info = self.epe([flow, valid], flow_predictions[-1]) self.flow_metrics['epe'].update_state(info['epe']) self.flow_metrics['u1'].update_state(info['u1']) self.flow_metrics['u3'].update_state(info['u3']) @@ -169,8 +171,8 @@ def reset_metrics(self): class SmallRAFT(RAFT): - def __init__(self, drop_rate=0, iters=12, **kwargs): - super().__init__(drop_rate, iters, **kwargs) + def __init__(self, drop_rate=0, iters=12, iters_pred=24, **kwargs): + super().__init__(drop_rate, iters, iters_pred, **kwargs) self.hidden_dim = hdim = 96 self.context_dim = cdim = 64 @@ 
-207,7 +209,8 @@ def call(self, inputs, training): coords0, coords1 = self.initialize_flow(image1) flow_predictions = [] - for i in range(self.iters): + iters = self.iters if training else self.iters_pred + for i in range(iters): corr = correlation.retrieve(coords1) flow = coords1 - coords0 diff --git a/tf_raft/training.py b/tf_raft/training.py index 30ff6d3..c9bdd9c 100644 --- a/tf_raft/training.py +++ b/tf_raft/training.py @@ -61,7 +61,7 @@ def on_epoch_end(self, epoch, logs=None): vis_ids = range(self.num_visualize) for i in vis_ids: - image1, image2, (scene, index) = self.dataset[i] + image1, image2, *_ = self.dataset[i] if len(image1.shape) > 3: raise ValueError('target dataset must not be batched') @@ -83,6 +83,6 @@ def on_epoch_end(self, epoch, logs=None): contents = np.concatenate([image1, image2, flow_img], axis=0) - filename = f'epoch{str(epoch+1).zfill(3)}_{scene}_{index}.png' + filename = f'epoch{str(epoch+1).zfill(3)}_{str(i+1).zfill(3)}.png' savepath = os.path.join(self.logdir, filename) imageio.imwrite(savepath, contents) diff --git a/train.py b/train.py deleted file mode 100644 index fb3a87f..0000000 --- a/train.py +++ /dev/null @@ -1,175 +0,0 @@ -import tensorflow as tf -import tensorflow_addons as tfa -import argparse -from functools import partial - -from tf_raft.model import RAFT, SmallRAFT -from tf_raft.losses import sequence_loss, end_point_error -from tf_raft.datasets import MpiSintel, FlyingChairs, ShapeSetter, CropOrPadder -from tf_raft.training import VisFlowCallback, first_cycle_scaler - - -def train(args): - try: - sintel_dir = args.sintel_dir - chairs_dir = args.chairs_dir - chairs_split_txt = args.chairs_split_txt - - epochs = args.epochs - batch_size = args.batch_size - iters = args.iters - learning_rate = args.learning_rate - weight_decay = args.weight_decay - - crop_size = args.crop_size - min_scale = args.min_scale - max_scale = args.max_scale - do_flip = args.do_flip - - num_visualize = args.num_visualize - resume = args.resume - 
except ValueError: - print('invalid arguments are given') - - aug_params = { - 'crop_size': crop_size, - 'min_scale': min_scale, - 'max_scale': max_scale, - 'do_flip': do_flip - } - - ds_train_sintel = MpiSintel(aug_params, - split='training', - root=sintel_dir, - dstype='clean') - ds_train_chairs = FlyingChairs(aug_params, - split='training', - split_txt=chairs_split_txt, - root=chairs_dir) - ds_train = 20*ds_train_sintel + ds_train_chairs - ds_train.shuffle() - train_size = len(ds_train) - print(f'Found {train_size} samples for training') - - ds_val = MpiSintel(split='training', - root=sintel_dir, - dstype='clean') - val_size = len(ds_val) - print(f'Found {val_size} samples for validation') - - ds_test = MpiSintel(split='test', - root=sintel_dir, - dstype='clean') - - ds_train = tf.data.Dataset.from_generator( - ds_train, - output_types=(tf.uint8, tf.uint8, tf.float32, tf.bool), - ) - ds_train = ds_train.repeat(epochs)\ - .batch(batch_size)\ - .map(ShapeSetter(batch_size, crop_size))\ - .prefetch(buffer_size=1) - - ds_val = tf.data.Dataset.from_generator( - ds_val, - output_types=(tf.uint8, tf.uint8, tf.float32, tf.bool), - ) - ds_val = ds_val.batch(1)\ - .map(ShapeSetter(batch_size=1, image_size=(436, 1024)))\ - .map(CropOrPadder(target_size=(448, 1024)))\ - .prefetch(buffer_size=1) - - scheduler = tfa.optimizers.CyclicalLearningRate( - initial_learning_rate=learning_rate, - maximal_learning_rate=2*learning_rate, - step_size=1000, - scale_fn=first_cycle_scaler, - scale_mode='cycle', - ) - - optimizer = tfa.optimizers.AdamW( - weight_decay=weight_decay, - learning_rate=scheduler - ) - - raft = RAFT(drop_rate=0, iters=iters) - raft.compile( - optimizer=optimizer, - loss=sequence_loss, - epe=end_point_error - ) - - if resume: - print('Restoring pretrained weights ...', end=' ') - raft.load_weights(resume) - print('done') - - callbacks = [ - tf.keras.callbacks.TensorBoard(), - VisFlowCallback( - ds_test, - num_visualize=num_visualize, - choose_random=True - ), - 
tf.keras.callbacks.ModelCheckpoint( - filepath='checkpoints/model', - save_weights_only=True, - monitor='val_epe', - mode='min', - save_best_only=True - ) - ] - - raft.fit( - ds_train, - epochs=epochs, - callbacks=callbacks, - steps_per_epoch=train_size//batch_size, - validation_data=ds_val, - validation_steps=val_size - ) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser('Training RAFT') - parser.add_argument('-sd', '--sintel_dir', type=str, required=True, - help='Path to the MPI-Sintel dataset directory') - parser.add_argument('-cd', '--chairs_dir', type=str, required=True, - help='Path to the FlyingChairs dataset directory') - parser.add_argument('--chairs-split-txt', type=str, required=True, - help='Path to the FlyingChairs split textfile') - - parser.add_argument('-e', '--epochs', default=1, type=int, - help='Number of epochs [1]') - parser.add_argument('-bs', '--batch_size', default=4, type=int, - help='Batch size [4]') - parser.add_argument('-lr', '--learning_rate', default=1.2e-4, type=float, - help='Learning rate [1.2e-4]') - parser.add_argument('--weight_decay', default=1e-5, type=float, - help='Weight decay in optimizer [1e-5]') - - parser.add_argument('--iters', default=6, type=int, - help='Number of iterations in RAFT inference [6]') - - parser.add_argument('--crop_size', nargs=2, type=int, default=[320, 448], - help='Crop size for raw image [320, 448]') - parser.add_argument('--min_scale', default=-0.1, type=float, - help='Minimum scale in augmentation [-0.1]') - parser.add_argument('--max_scale', default=0.1, type=float, - help='Maximum scale in augmentation [0.1]') - parser.add_argument('--disable_flip', dest='do_flip', action='store_false', - help='Disable flip in augmentation [True]') - parser.set_defaults(do_flip=True) - - parser.add_argument('-nv', '--num-visualize', type=int, default=4, - help='Number of visualization per epoch [4]') - parser.add_argument('-r', '--resume', type=str, default=None, - help='Pretrained checkpoints 
[None]') - args = parser.parse_args() - - print('---- Config ----') - for k, v in vars(args).items(): - print(f'{k}: {v}') - print('------------------') - - train(args) diff --git a/train_chairs.py b/train_chairs.py new file mode 100644 index 0000000..df2e576 --- /dev/null +++ b/train_chairs.py @@ -0,0 +1,152 @@ +import os +import yaml +import argparse +import tensorflow as tf +import tensorflow_addons as tfa +from datetime import datetime + +from tf_raft.model import RAFT, SmallRAFT +from tf_raft.losses import sequence_loss, end_point_error +from tf_raft.datasets import FlyingChairs, ShapeSetter, CropOrPadder +from tf_raft.training import VisFlowCallback, first_cycle_scaler + + +def train(config, logdir): + try: + data_config = config['data'] + root = data_config['root'] + split_txt = data_config['split_txt'] + + aug_params = config['augment'] + crop_size = aug_params['crop_size'] + + model_config = config['model'] + iters = model_config['iters'] + iters_pred = model_config['iters_pred'] + resume = model_config['resume'] + + train_config = config['train'] + epochs = train_config['epochs'] + batch_size = train_config['batch_size'] + learning_rate = train_config['learning_rate'] + weight_decay = train_config['weight_decay'] + clip_norm = train_config['clip_norm'] + + vis_config = config['visualize'] + num_visualize = vis_config['num_visualize'] + choose_random = vis_config['choose_random'] + except ValueError: + print('invalid arguments are given') + + # training set + ds_train = FlyingChairs(aug_params, + split='training', + split_txt=split_txt, + root=root) + ds_train.shuffle() + train_size = len(ds_train) + print(f'Found {train_size} samples for training') + + ds_train = tf.data.Dataset.from_generator( + ds_train, + output_types=(tf.uint8, tf.uint8, tf.float32, tf.bool), + ) + ds_train = ds_train.repeat(epochs)\ + .batch(batch_size)\ + .map(ShapeSetter(batch_size, crop_size))\ + .prefetch(buffer_size=1) + + # validation set + ds_val = 
FlyingChairs(split='validation', + split_txt=split_txt, + root=root) + val_size = len(ds_val) + print(f'Found {val_size} samples for validation') + + ds_val = tf.data.Dataset.from_generator( + ds_val, + output_types=(tf.uint8, tf.uint8, tf.float32, tf.bool), + ) + ds_val = ds_val.batch(1)\ + .map(ShapeSetter(batch_size=1, image_size=(384, 512)))\ + .prefetch(buffer_size=1) + + # for visualization + ds_vis = FlyingChairs(split='validation', + split_txt=split_txt, + root=root) + + scheduler = tfa.optimizers.CyclicalLearningRate( + initial_learning_rate=learning_rate, + maximal_learning_rate=2*learning_rate, + step_size=1000, + scale_fn=first_cycle_scaler, + scale_mode='cycle', + ) + + optimizer = tfa.optimizers.AdamW( + weight_decay=weight_decay, + learning_rate=scheduler, + ) + + raft = RAFT(drop_rate=0, iters=iters, iters_pred=iters_pred) + raft.compile( + optimizer=optimizer, + clip_norm=clip_norm, + loss=sequence_loss, + epe=end_point_error + ) + + if resume: + print('Restoring pretrained weights ...', end=' ') + raft.load_weights(resume) + print('done') + + callbacks = [ + tf.keras.callbacks.TensorBoard(log_dir=logdir+'/history'), + VisFlowCallback( + ds_vis, + num_visualize=num_visualize, + choose_random=choose_random, + logdir=logdir+'/predicted_flows' + ), + tf.keras.callbacks.ModelCheckpoint( + filepath=logdir+'/checkpoints/model', + save_weights_only=True, + monitor='val_epe', + mode='min', + save_best_only=True + ) + ] + + raft.fit( + ds_train, + epochs=epochs, + callbacks=callbacks, + steps_per_epoch=train_size//batch_size, + validation_data=ds_val, + validation_steps=val_size + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Training RAFT') + parser.add_argument('config', type=str, help='Training config file') + args = parser.parse_args() + + with open(args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + + logd = config['logdir'] + logdir = os.path.join(logd, 
datetime.now().strftime("%Y-%m-%dT%H-%M")) + if not os.path.exists(logdir): + os.makedirs(logdir) + + savepath = logdir + "/config.yml" + with open(savepath, 'w') as f: + f.write(yaml.dump(config, default_flow_style=False)) + + print('------------ Config ---------------') + print(yaml.dump(config)) + print('-----------------------------------') + train(config, logdir) diff --git a/train_sintel.py b/train_sintel.py new file mode 100644 index 0000000..6979fe4 --- /dev/null +++ b/train_sintel.py @@ -0,0 +1,156 @@ +import os +import yaml +import argparse +import tensorflow as tf +import tensorflow_addons as tfa +from datetime import datetime + +from tf_raft.model import RAFT, SmallRAFT +from tf_raft.losses import sequence_loss, end_point_error +from tf_raft.datasets import MpiSintel, ShapeSetter, CropOrPadder +from tf_raft.training import VisFlowCallback, first_cycle_scaler + + +def train(config, logdir): + try: + data_config = config['data'] + root = data_config['root'] + + aug_params = config['augment'] + crop_size = aug_params['crop_size'] + + model_config = config['model'] + iters = model_config['iters'] + iters_pred = model_config['iters_pred'] + resume = model_config['resume'] + + train_config = config['train'] + epochs = train_config['epochs'] + batch_size = train_config['batch_size'] + learning_rate = train_config['learning_rate'] + weight_decay = train_config['weight_decay'] + clip_norm = train_config['clip_norm'] + + vis_config = config['visualize'] + num_visualize = vis_config['num_visualize'] + choose_random = vis_config['choose_random'] + except ValueError: + print('invalid arguments are given') + + # training set + ds_train = MpiSintel(aug_params, + split='training', + root=root, + dstype='clean') + ds_train.shuffle() + train_size = len(ds_train) + print(f'Found {train_size} samples for training') + + ds_train = tf.data.Dataset.from_generator( + ds_train, + output_types=(tf.uint8, tf.uint8, tf.float32, tf.bool), + ) + ds_train = ds_train.repeat(epochs)\ + 
.batch(batch_size)\ + .map(ShapeSetter(batch_size, crop_size))\ + .prefetch(buffer_size=1) + + # validation set + val_size = int(0.1*train_size) + ds_val = MpiSintel(split='training', + root=root, + dstype='clean') + ds_val.shuffle() + ds_val.image_list = ds_val.image_list[:val_size] + ds_val.flow_list = ds_val.flow_list[:val_size] + print(f'Found {val_size} samples for validation') + + ds_val = tf.data.Dataset.from_generator( + ds_val, + output_types=(tf.uint8, tf.uint8, tf.float32, tf.bool), + ) + ds_val = ds_val.batch(1)\ + .map(ShapeSetter(batch_size=1, image_size=(436, 1024)))\ + .map(CropOrPadder(target_size=(448, 1024)))\ + .prefetch(buffer_size=1) + + # for visualization + ds_vis = MpiSintel(split='training', + root=root, + dstype='clean') + ds_vis.shuffle() + + scheduler = tfa.optimizers.CyclicalLearningRate( + initial_learning_rate=learning_rate, + maximal_learning_rate=2*learning_rate, + step_size=1000, + scale_fn=first_cycle_scaler, + scale_mode='cycle', + ) + + optimizer = tfa.optimizers.AdamW( + weight_decay=weight_decay, + learning_rate=scheduler, + ) + + raft = RAFT(drop_rate=0, iters=iters, iters_pred=iters_pred) + raft.compile( + optimizer=optimizer, + clip_norm=clip_norm, + loss=sequence_loss, + epe=end_point_error + ) + + if resume: + print('Restoring pretrained weights ...', end=' ') + raft.load_weights(resume) + print('done') + + callbacks = [ + tf.keras.callbacks.TensorBoard(log_dir=logdir+'/history'), + VisFlowCallback( + ds_vis, + num_visualize=num_visualize, + choose_random=choose_random, + logdir=logdir+'/predicted_flows' + ), + tf.keras.callbacks.ModelCheckpoint( + filepath=logdir+'/checkpoints/model', + save_weights_only=True, + monitor='val_epe', + mode='min', + save_best_only=True + ) + ] + + raft.fit( + ds_train, + epochs=epochs, + callbacks=callbacks, + steps_per_epoch=train_size//batch_size, + validation_data=ds_val, + validation_steps=val_size + ) + + +if __name__ == '__main__': + parser = 
argparse.ArgumentParser(description='Training RAFT') + parser.add_argument('config', type=str, help='Training config file') + args = parser.parse_args() + + with open(args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + + logd = config['logdir'] + logdir = os.path.join(logd, datetime.now().strftime("%Y-%m-%dT%H-%M")) + if not os.path.exists(logdir): + os.makedirs(logdir) + + savepath = logdir + "/config.yml" + with open(savepath, 'w') as f: + f.write(yaml.dump(config, default_flow_style=False)) + + print('------------ Config ---------------') + print(yaml.dump(config)) + print('-----------------------------------') + train(config, logdir)