Skip to content

Commit

Permalink
Merge pull request #22 from daigo0927/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
daigo0927 authored Oct 3, 2020
2 parents 722a8db + 2086a8f commit 3c85f54
Show file tree
Hide file tree
Showing 36 changed files with 446 additions and 220 deletions.
60 changes: 45 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,70 @@ see details in `pyoroject.toml`
from tf_raft.model import RAFT, SmallRAFT
from tf_raft.losses import sequence_loss, end_point_error

# iters means number of recurrent update of flow
raft = RAFT(iters=iters)
# iters/iters_pred are the number of recurrent update of flow in training/prediction
raft = RAFT(iters=iters, iters_pred=iters_pred)
raft.compile(
optimizer=optimizer,
clip_norm=clip_norm,
loss=sequence_loss,
epe=end_point_error
)

raft.fit(
dataset,
ds_train,
epochs=epochs,
callbacks=callbacks,
steps_per_epoch=train_size//batch_size,
validation_data=ds_val,
validation_steps=val_size
)
```

In practice, you are required to prepare dataset, optimizer, callbacks etc, check details in `train.py`.
In practice, you are required to prepare dataset, optimizer, callbacks etc, check details in `train_sintel.py` or `train_chairs.py`.

## Load the pretrained weights
### Train via YAML configuration

You can download the pretrained weights via `gsutil` or `curl` (trained on MPI-Sintel Clean, and FlyingChairs)
`train_chairs.py` and `train_sintel.py` train RAFT model via YAML configuration. Sample configs are in `configs` directory. Run;

``` shell
$ gsutil cp -r gs://tf-raft-pretrained/checkpoints .
$ python train_chairs.py /path/to/config.yml
```

## Pre-trained models

I made the pre-trained weights (on both FlyingChairs and MPI-Sintel) public.
You can download them via `gsutil` or `curl`.

### Trained weights on FlyingChairs

``` shell
$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T18-38/checkpoints .
```
or
``` shell
$ mkdir checkpoints
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.index
$ mv model* checkpoints/
```

### Trained weights on MPI-Sintel (Clean path)

``` shell
$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T08-51/checkpoints .
```
or
``` shell
$ mkdir checkpoints
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/checkpoints/model.index
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.index
$ mv model* checkpoints/
```

then
### Load weights

``` python
raft = RAFT(iters=iters)
raft = RAFT(iters=iters, iters_pred=iters_pred)
raft.load_weights('checkpoints/model')

# forward (with dummy inputs)
Expand All @@ -76,11 +104,13 @@ print(flow_predictions[-1].shape) # >> (1, 448, 512, 2)
```

## Note
Though I have tried to reproduce the original implementation faithfully, there is some difference between it and my implementation (mainly because of used framework: PyTorch/TensorFlow);
Though I have tried to reproduce the original implementation faithfully, there is some difference between the original one and mine (mainly because of used framework: PyTorch/TensorFlow);

- The original implementations provides cuda-based correlation function but I don't. My TF-based implementation works well, but cuda-based one may run faster.
- I have trained my model on FlyingChairs and MPI-Sintel separately in my private environment (GCP with P100 accelerator). The model has been trained well, but not reached the best score reported in the paper (trained on multiple datasets).
- The original one uses mixed-precision. This may get training much faster, but I don't. TensorFlow also enables mixed-precision with few additional lines, see https://www.tensorflow.org/guide/mixed_precision if interested.

- The original implements cuda-based correlation function but I don't. My TF-based implementation works well, but cuda-based one may runs faster.
- I have trained my model only on MPI-Sintel dataset in my private environment (GCP with P100 accelerator). The model has been trained well, but not reached the best score reported in the paper (trained on multiple datasets).
- The original uses mixed-precision. This may get traininig much faster, but I don't. TensorFlow also enables mixed-precision with few additional lines, see https://www.tensorflow.org/guide/mixed_precision if interested.
Additional, global gradient clipping seems to be essential for stable training though it is not emphasized in the original paper. This operation can be done via `torch.nn.utils.clip_grad_norm_(model.parameters(), clip)` in PyTorch, `tf.clip_by_global_norm(grads, clip_norm)` in TF (coded at `self.train_step` in `tf_raft/model.py`).

## References
- https://github.com/princeton-vl/RAFT
Expand Down
27 changes: 27 additions & 0 deletions configs/train_chairs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
data:
root: './FlyingChairs_release/data'
split_txt: './FlyingChairs_release/FlyingChairs_train_val.txt'

augment:
crop_size: [368, 496]
min_scale: -0.1
max_scale: 1.0
do_flip: True

model:
iters: 12
iters_pred: 24
resume: 0 # '/path/to/checkpoint'

train:
epochs: 1
batch_size: 4
learning_rate: 0.0004
weight_decay: 0.0001
clip_norm: 1

visualize:
num_visualize: 4
choose_random: True

logdir: './logs'
26 changes: 26 additions & 0 deletions configs/train_sintel.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
data:
root: './MPI-Sintel-complete'

augment:
crop_size: [368, 768]
min_scale: -0.2
max_scale: 0.6
do_flip: True

model:
iters: 12
iters_pred: 24
resume: 0 # '/path/to/checkpoint'

train:
epochs: 1
batch_size: 4
learning_rate: 0.00012
weight_decay: 0.00001
clip_norm: 1

visualize:
num_visualize: 4
choose_random: True

logdir: './logs'
Binary file removed predicted_flows/epoch001_bamboo_3_17.png
Binary file not shown.
Binary file removed predicted_flows/epoch001_cave_3_37.png
Binary file not shown.
Binary file removed predicted_flows/epoch001_market_1_42.png
Binary file not shown.
Binary file removed predicted_flows/epoch001_temple_1_0.png
Binary file not shown.
Binary file removed predicted_flows/epoch002_PERTURBED_market_3_26.png
Binary file not shown.
Binary file removed predicted_flows/epoch002_PERTURBED_shaman_1_2.png
Binary file not shown.
Binary file removed predicted_flows/epoch002_mountain_2_10.png
Binary file not shown.
Binary file removed predicted_flows/epoch002_wall_18.png
Binary file not shown.
Binary file removed predicted_flows/epoch003_PERTURBED_market_3_24.png
Binary file not shown.
Binary file removed predicted_flows/epoch003_bamboo_3_25.png
Binary file not shown.
Binary file removed predicted_flows/epoch003_bamboo_3_31.png
Binary file not shown.
Binary file removed predicted_flows/epoch003_market_1_11.png
Binary file not shown.
Binary file removed predicted_flows/epoch004_PERTURBED_market_3_7.png
Binary file not shown.
Binary file removed predicted_flows/epoch004_cave_3_12.png
Binary file not shown.
Binary file removed predicted_flows/epoch004_market_4_25.png
Binary file not shown.
Binary file removed predicted_flows/epoch004_wall_30.png
Binary file not shown.
Binary file removed predicted_flows/epoch005_bamboo_3_36.png
Binary file not shown.
Binary file removed predicted_flows/epoch005_market_1_31.png
Binary file not shown.
Binary file removed predicted_flows/epoch005_market_4_45.png
Binary file not shown.
Binary file removed predicted_flows/epoch005_mountain_2_2.png
Binary file not shown.
Binary file added samples_sintel/epoch030_204.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added samples_sintel/epoch030_362.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added samples_sintel/epoch030_899.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added samples_sintel/epoch030_988.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/losses/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_end_point_error(data):
u3_valid = 3 / 8
u5_valid = 5 / 8

info = end_point_error([flow_gt, valid], predictions)
info = end_point_error([flow_gt, valid], predictions[-1])

# 2 decimal for sqrt precision
np.testing.assert_almost_equal(info['epe'], epe_valid, decimal=2)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,35 @@ def test_upsample_flow():

def test_raft():
iters = 6
iters_pred = 12
image1 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3))
image2 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3))

model = RAFT(drop_rate=0.0, iters=iters)
model = RAFT(drop_rate=0.0, iters=iters, iters_pred=iters_pred)
output = model([image1, image2], training=True)
assert len(output) == iters
for flow in output:
assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2)

output = model([image1, image2], training=False)
assert len(output) == iters
assert len(output) == iters_pred
for flow in output:
assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2)


def test_small_raft():
iters = 6
iters_pred = 12
image1 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3))
image2 = tf.random.normal((BATCH_SIZE, *IMAGE_SIZE, 3))

model = SmallRAFT(drop_rate=0.0, iters=iters)
model = SmallRAFT(drop_rate=0.0, iters=iters, iters_pred=iters_pred)
output = model([image1, image2], training=True)
assert len(output) == iters
for flow in output:
assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2)

output = model([image1, image2], training=False)
assert len(output) == iters
assert len(output) == iters_pred
for flow in output:
assert flow.shape == (BATCH_SIZE, *IMAGE_SIZE, 2)
29 changes: 17 additions & 12 deletions tf_raft/datasets/augmentor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import numpy as np
import random
import math
from PIL import Image

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

# import torch
# from torchvision.transforms import ColorJitter
# import torch.nn.functional as F
import albumentations as A
from PIL import Image


class FlowAugmentor:
Expand All @@ -37,7 +30,7 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
contrast_limit=0.4
),
A.HueSaturationValue(
hue_shift_limit=int(0.5/3.14*255),
hue_shift_limit=int(0.5/3.14*180),
sat_shift_limit=int(0.4*255),
val_shift_limit=int(0.)
)
Expand Down Expand Up @@ -135,6 +128,7 @@ def __call__(self, img1, img2, flow):

return img1, img2, flow


class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
Expand All @@ -151,13 +145,25 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
self.v_flip_prob = 0.1

# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
# self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.photo_aug = A.Compose([
A.RandomBrightnessContrast(
brightness_limit=0.3,
contrast_limit=0.3
),
A.HueSaturationValue(
hue_shift_limit=int(0.3/3.14*180),
sat_shift_limit=int(0.3*255),
val_shift_limit=int(0.)
)
])
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5

def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
# image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
image_stack = self.photo_aug(image=image_stack)['image']
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2

Expand Down Expand Up @@ -248,7 +254,6 @@ def spatial_transform(self, img1, img2, flow, valid):
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow, valid


def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
Expand Down
2 changes: 1 addition & 1 deletion tf_raft/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def end_point_error(y_true, y_pred, max_flow=400):
mag = tf.sqrt(tf.reduce_sum(flow_gt**2, axis=-1))
valid = valid & (mag < max_flow)

epe = tf.sqrt(tf.reduce_sum((y_pred[-1] - flow_gt)**2, axis=-1))
epe = tf.sqrt(tf.reduce_sum((y_pred - flow_gt)**2, axis=-1))
epe = epe[valid]
epe_under1 = tf.cast(epe < 1, dtype=tf.float32)
epe_under3 = tf.cast(epe < 3, dtype=tf.float32)
Expand Down
23 changes: 13 additions & 10 deletions tf_raft/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class RAFT(tf.keras.Model):
def __init__(self, drop_rate=0, iters=12, **kwargs):
def __init__(self, drop_rate=0, iters=12, iters_pred=24, **kwargs):
super().__init__(**kwargs)

self.hidden_dim = hdim = 128
Expand All @@ -19,6 +19,7 @@ def __init__(self, drop_rate=0, iters=12, **kwargs):
self.drop_rate = drop_rate

self.iters = iters
self.iters_pred = iters_pred

self.fnet = BasicEncoder(output_dim=256,
norm_type='instance',
Expand Down Expand Up @@ -88,7 +89,8 @@ def call(self, inputs, training):
coords0, coords1 = self.initialize_flow(image1)

flow_predictions = []
for i in range(self.iters):
iters = self.iters if training else self.iters_pred
for i in range(iters):
# (bs, h, w, 81xnum_levels)
corr = correlation.retrieve(coords1)

Expand All @@ -106,9 +108,10 @@ def call(self, inputs, training):
# flow_predictions[-1] is the finest output
return flow_predictions

def compile(self, optimizer, loss, epe, **kwargs):
def compile(self, optimizer, clip_norm, loss, epe, **kwargs):
super().compile(**kwargs)
self.optimizer = optimizer
self.clip_norm = clip_norm
self.loss = loss
self.epe = epe

Expand All @@ -129,9 +132,10 @@ def train_step(self, data):
flow_predictions = self([image1, image2], training=True)
loss = self.loss([flow, valid], flow_predictions)
grads = tape.gradient(loss, self.trainable_weights)
grads, _ = tf.clip_by_global_norm(grads, self.clip_norm)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

info = self.epe([flow, valid], flow_predictions)
info = self.epe([flow, valid], flow_predictions[-1])
self.flow_metrics['loss'].update_state(loss)
self.flow_metrics['epe'].update_state(info['epe'])
self.flow_metrics['u1'].update_state(info['u1'])
Expand All @@ -145,10 +149,8 @@ def test_step(self, data):
image2 = tf.cast(image2, dtype=tf.float32)

flow_predictions = self([image1, image2], training=False)
loss = self.loss([flow, valid], flow_predictions)

info = self.epe([flow, valid], flow_predictions)
self.flow_metrics['loss'].update_state(loss)
info = self.epe([flow, valid], flow_predictions[-1])
self.flow_metrics['epe'].update_state(info['epe'])
self.flow_metrics['u1'].update_state(info['u1'])
self.flow_metrics['u3'].update_state(info['u3'])
Expand All @@ -169,8 +171,8 @@ def reset_metrics(self):


class SmallRAFT(RAFT):
def __init__(self, drop_rate=0, iters=12, **kwargs):
super().__init__(drop_rate, iters, **kwargs)
def __init__(self, drop_rate=0, iters=12, iters_pred=24, **kwargs):
super().__init__(drop_rate, iters, iters_pred, **kwargs)

self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
Expand Down Expand Up @@ -207,7 +209,8 @@ def call(self, inputs, training):
coords0, coords1 = self.initialize_flow(image1)

flow_predictions = []
for i in range(self.iters):
iters = self.iters if training else self.iters_pred
for i in range(iters):
corr = correlation.retrieve(coords1)

flow = coords1 - coords0
Expand Down
4 changes: 2 additions & 2 deletions tf_raft/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def on_epoch_end(self, epoch, logs=None):
vis_ids = range(self.num_visualize)

for i in vis_ids:
image1, image2, (scene, index) = self.dataset[i]
image1, image2, *_ = self.dataset[i]
if len(image1.shape) > 3:
raise ValueError('target dataset must not be batched')

Expand All @@ -83,6 +83,6 @@ def on_epoch_end(self, epoch, logs=None):

contents = np.concatenate([image1, image2, flow_img], axis=0)

filename = f'epoch{str(epoch+1).zfill(3)}_{scene}_{index}.png'
filename = f'epoch{str(epoch+1).zfill(3)}_{str(i+1).zfill(3)}.png'
savepath = os.path.join(self.logdir, filename)
imageio.imwrite(savepath, contents)
Loading

0 comments on commit 3c85f54

Please sign in to comment.