Add deepest_conv_lstm_conv3d_pilotnet TF model #42

Closed
wants to merge 1 commit into from
138 changes: 138 additions & 0 deletions Formula1-FollowLine/tensorflow/DeepestConvLSTMConv3dPilotNet/train.py
@@ -0,0 +1,138 @@
import time
import datetime
import argparse
import h5py

from tensorflow.python.keras.saving import hdf5_format
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, CSVLogger
from utils.processing import process_dataset
# NOTE: module path inferred from the model file added in this PR; the script
# calls deepest_conv_lstm_conv3d_pilotnet_model below
from utils.deepest_conv_lstm_conv3d_pilotnet import deepest_conv_lstm_conv3d_pilotnet_model
from utils.dataset import get_augmentations, DatasetSequence


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_dir", action='append', help="Directory to find dataset")
    parser.add_argument("--preprocess", action='append', default=None, help="Preprocessing information: choose from crop/no_crop and extreme/no_extreme")
    # parser.add_argument("--base_dir", type=str, default='exp_random', help="Directory to save everything")
    parser.add_argument("--data_augs", action='append', type=bool, default=None, help="Data augmentations True/False (note: bool() treats any non-empty string, including 'False', as True)")
    parser.add_argument("--num_epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    # parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for model training")
    parser.add_argument("--img_shape", type=str, default='200,66,3', help="Image shape as comma-separated integers")

    args = parser.parse_args()
    return args
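
# Example invocation (hypothetical paths; flags match the parser above):
#   python train.py --data_dir ../datasets/f1_followline \
#       --preprocess crop --preprocess no_extreme \
#       --data_augs True --img_shape 200,66,3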


if __name__ == "__main__":
    args = parse_args()
    path_to_data = args.data_dir[0]
    preprocess = args.preprocess or []  # guard against None when --preprocess is omitted
    data_augs = args.data_augs
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    # learning_rate = args.learning_rate
    img_shape = tuple(map(int, args.img_shape.split(',')))

    if 'no_crop' in preprocess:
        type_image = 'no_crop'
    else:
        type_image = 'cropped'

    if 'extreme' in preprocess:
        data_type = 'extreme'
    else:
        data_type = 'no_extreme'

    images_train, annotations_train, images_validation, annotations_validation = process_dataset(path_to_data, type_image, data_type, img_shape)

    timestr = time.strftime("%Y%m%d-%H%M%S")
    print(timestr)
    print(images_train.shape)
    print(annotations_train.shape)
    print(images_validation.shape)
    print(annotations_validation.shape)

    # NOTE: these hard-coded values take precedence over the --num_epochs and
    # --batch_size CLI arguments parsed above
    hparams = {
        'train_batch_size': 50,
        'val_batch_size': 50,
        'batch_size': 50,
        'n_epochs': 300,
        'checkpoint_dir': '../logs_test/'
    }

    print(hparams)

    model_name = 'deepest_conv_lstm_conv3d_pilotnet_model'
    model = deepest_conv_lstm_conv3d_pilotnet_model(img_shape)
    model_filename = timestr + 'deepest_conv_lstm_conv3d_pilotnet_model'
    model_file = model_filename + '.h5'

    # Build the augmentation pipelines used by the generators below
    AUGMENTATIONS_TRAIN, AUGMENTATIONS_TEST = get_augmentations(data_augs)

    # Training data
    train_gen = DatasetSequence(images_train, annotations_train, hparams['batch_size'],
                                augmentations=AUGMENTATIONS_TRAIN)

    # Validation data
    valid_gen = DatasetSequence(images_validation, annotations_validation, hparams['batch_size'],
                                augmentations=AUGMENTATIONS_TEST)

    # Define callbacks
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
    earlystopping = EarlyStopping(monitor="mae", patience=30, verbose=1, mode='auto')
    # Create a callback that saves the model's weights
    checkpoint_path = model_filename + '_cp.h5'
    cp_callback = ModelCheckpoint(filepath=checkpoint_path, monitor='val_loss', save_best_only=True, verbose=1)
    csv_logger = CSVLogger(model_filename + '.csv', append=True)

    # Print layers
    print(model)
    model.build((None,) + img_shape)  # build() expects a batch dimension
    print(model.summary())
    # Training
    model.fit(
        train_gen,
        epochs=hparams['n_epochs'],
        verbose=2,
        validation_data=valid_gen,
        # workers=2, use_multiprocessing=False,
        callbacks=[tensorboard_callback, earlystopping, cp_callback, csv_logger])

    # Save the model
    model.save(model_file)

    # Evaluate the model (evaluate_generator is deprecated; evaluate accepts a Sequence)
    score = model.evaluate(valid_gen, verbose=0)

    print('Evaluating')
    print('Test loss: ', score[0])
    print('Test mean squared error: ', score[1])
    print('Test mean absolute error: ', score[2])

    # SAVE METADATA (h5py and hdf5_format are already imported at the top of the file)
    model_path = model_file
    # Re-save the model together with experiment metadata as HDF5 attributes
    with h5py.File(model_path, mode='w') as f:
        hdf5_format.save_model_to_hdf5(model, f)
        f.attrs['experiment_name'] = ''
        f.attrs['experiment_description'] = ''
        f.attrs['batch_size'] = hparams['train_batch_size']
        f.attrs['nb_epoch'] = hparams['n_epochs']
        f.attrs['model'] = model_name
        f.attrs['img_shape'] = img_shape
        f.attrs['normalized_dataset'] = True
        f.attrs['sequences_dataset'] = True
        f.attrs['gpu_trained'] = True
        f.attrs['data_augmentation'] = True
        f.attrs['extreme_data'] = False
        f.attrs['split_test_train'] = 0.30
        f.attrs['instances_number'] = len(annotations_train)
        f.attrs['loss'] = score[0]
        f.attrs['mse'] = score[1]
        f.attrs['mae'] = score[2]
        f.attrs['csv_path'] = model_filename + '.csv'
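
A minimal sketch (assumed placeholder file name) of reading the metadata written above back out of the saved HDF5 file:

import h5py

# 'model.h5' stands in for the timestamped .h5 file produced by train.py
with h5py.File('model.h5', 'r') as f:
    for key, value in f.attrs.items():
        print(key, '=', value)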
66 changes: 66 additions & 0 deletions Formula1-FollowLine/tensorflow/DeepestConvLSTMConv3dPilotNet/utils/dataset.py
@@ -0,0 +1,66 @@
import math

import numpy as np

from tensorflow.keras.utils import Sequence
from albumentations import (
    ReplayCompose, Compose, HorizontalFlip, RandomBrightnessContrast,
    HueSaturationValue, FancyPCA, RandomGamma, GaussNoise,
    GaussianBlur, ToFloat, Normalize, ColorJitter, ChannelShuffle, Equalize
)


class DatasetSequence(Sequence):
    def __init__(self, x_set, y_set, batch_size, augmentations):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.augment = augmentations

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        new_batch = []
        for img in batch_x:
            # Record the random transform on the first frame, then replay the
            # exact same transform on every frame of the 3-frame sequence
            aug = self.augment(image=img[0])
            augmented_0 = self.augment.replay(saved_augmentations=aug['replay'], image=img[0])["image"]
            augmented_1 = self.augment.replay(saved_augmentations=aug['replay'], image=img[1])["image"]
            augmented_2 = self.augment.replay(saved_augmentations=aug['replay'], image=img[2])["image"]
            new_batch.append(np.array([augmented_0, augmented_1, augmented_2]))

        return np.stack(new_batch, axis=0), np.array(batch_y)


def get_augmentations(data_augs):
    if data_augs:
        AUGMENTATIONS_TRAIN = ReplayCompose([
            RandomBrightnessContrast(),
            HueSaturationValue(),
            FancyPCA(),
            RandomGamma(),
            GaussianBlur(),
            # GaussNoise(),
            # ColorJitter(),
            # Equalize(),
            # ChannelShuffle(),
            # CoarseDropout(),
            Normalize()
        ])
    else:
        AUGMENTATIONS_TRAIN = ReplayCompose([
            Normalize()
        ])

    AUGMENTATIONS_TEST = ReplayCompose([
        Normalize()
    ])

    return AUGMENTATIONS_TRAIN, AUGMENTATIONS_TEST
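
A quick sketch (synthetic data, assumed 3-frame 66x200 RGB samples) showing that ReplayCompose applies the identical random transform to every frame of a sequence:

import numpy as np
from utils.dataset import get_augmentations, DatasetSequence

AUGMENTATIONS_TRAIN, _ = get_augmentations(data_augs=True)

# 10 samples of 3 stacked frames each (shapes are illustrative only)
x = np.random.randint(0, 255, (10, 3, 66, 200, 3), dtype=np.uint8)
y = np.random.rand(10, 2).astype(np.float32)

seq = DatasetSequence(x, y, batch_size=4, augmentations=AUGMENTATIONS_TRAIN)
batch_x, batch_y = seq[0]
print(batch_x.shape, batch_y.shape)  # expected: (4, 3, 66, 200, 3) (4, 2)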

30 changes: 30 additions & 0 deletions Formula1-FollowLine/tensorflow/DeepestConvLSTMConv3dPilotNet/utils/deepest_conv_lstm_conv3d_pilotnet.py
@@ -0,0 +1,30 @@
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, TimeDistributed, Dense, LSTM, Conv3D, BatchNormalization, ConvLSTM2D, \
    Dropout, Reshape, Activation, MaxPooling3D
from tensorflow.keras.optimizers import Adam


# DeepestConvLSTMConv3DPilotNet
def deepest_conv_lstm_conv3d_pilotnet_model(img_shape):
    model = Sequential()
    model.add(BatchNormalization(epsilon=0.001, axis=-1, input_shape=img_shape))

    model.add(Conv3D(24, (5, 5, 5), strides=(2, 2, 2), activation="relu", padding='same'))
    model.add(Conv3D(36, (5, 5, 5), strides=(2, 2, 2), activation="relu", padding='same'))
    model.add(Conv3D(48, (5, 5, 5), strides=(2, 2, 2), activation="relu", padding='same'))
    model.add(Conv3D(64, (3, 3, 3), strides=(1, 1, 1), activation="relu", padding='same'))
    model.add(Conv3D(64, (3, 3, 3), strides=(1, 1, 1), activation="relu", padding='same'))

    # NOTE: this Reshape hard-codes the flattened Conv3D output (1*13*64*7 = 5824
    # elements), so the model only builds for input shapes whose conv output matches
    model.add(Reshape((1, 13, 64, 7)))

    model.add(ConvLSTM2D(filters=8, kernel_size=(5, 5), padding="same", return_sequences=True))
    model.add(ConvLSTM2D(filters=8, kernel_size=(5, 5), padding="same", return_sequences=True))
    model.add(ConvLSTM2D(filters=8, kernel_size=(5, 5), padding="same", return_sequences=True))
    model.add(ConvLSTM2D(filters=8, kernel_size=(5, 5), padding="same", return_sequences=False))
    model.add(Flatten())
    model.add(Dense(50, activation="relu"))
    model.add(Dense(10, activation="relu"))
    model.add(Dense(2))
    adam = Adam(learning_rate=0.00001)  # 'lr' is deprecated in tf.keras
    model.compile(optimizer=adam, loss="mse", metrics=['mse', 'mae'])
    return model
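
A smoke-test sketch (assumed input size and module path): with 3 frames of 100x50 RGB, the Conv3D stack emits 1x13x7x64 = 5824 values, which is exactly what the hard-coded Reshape((1, 13, 64, 7)) expects:

import numpy as np
from utils.deepest_conv_lstm_conv3d_pilotnet import deepest_conv_lstm_conv3d_pilotnet_model

model = deepest_conv_lstm_conv3d_pilotnet_model((3, 100, 50, 3))
dummy = np.zeros((1, 3, 100, 50, 3), dtype=np.float32)
print(model.predict(dummy).shape)  # (1, 2): the two regression outputs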