diff --git a/src/imitation_learning/config.json b/src/imitation_learning/config.json
new file mode 100644
index 0000000..e4b0fe2
--- /dev/null
+++ b/src/imitation_learning/config.json
@@ -0,0 +1,63 @@
+{
+    "backbone": "resnet18",
+    "backbone_pretrained": true,
+    "num_outputs": 2,
+    "max_epochs": 100,
+    "optimizer": "adam",
+    "learning_rate": 0.0001,
+    "loss": "mse",
+    "num_workers": 8,
+    "batch_size": 256,
+    "val_frequency": 1,
+    "save_frequency": 5,
+    "run_name": "resnet18_first_attempt",
+    "outdir": "local/data/ackermann_plus",
+    "train_datasets": [
+        "local/data/Datasets_Ackermann/montmelo_ackermann_follow_lane",
+        "local/data/Datasets_Ackermann/many_curves_ackermann_follow_lane"
+    ],
+    "val_datasets": [
+        "local/data/Datasets_Ackermann/montreal_ackermann_follow_lane"
+    ],
+    "test_datasets": [
+        "local/data/Datasets_Ackermann/nurburgring_ackermann_follow_lane"
+    ],
+    "preprocessing": [
+        "flip",
+        "cropped",
+        "extreme"
+    ],
+    "augmentations": {
+        "ColorJitter": {
+            "brightness": 0.2,
+            "contrast": 0.2,
+            "saturation": 0.2,
+            "hue": 0.1,
+            "p": 0.5
+        },
+        "GaussianBlur": {
+            "kernel_size": 3,
+            "sigma": [
+                0.1,
+                2.0
+            ],
+            "p": 0.5
+        },
+        "GaussianNoise": {
+            "mean": 0,
+            "sigma": 0.1,
+            "p": 0.5
+        },
+        "RandomErasing": {
+            "scale": [
+                0.02,
+                0.33
+            ],
+            "ratio": [
+                0.3,
+                3.3
+            ],
+            "p": 0.5
+        }
+    }
+}
\ No newline at end of file
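Note: the augmentation key is "GaussianNoise" (not "GaussNoise") so that it matches the key train.py looks up; train.py silently skips any augmentation name it does not recognize, so a typo here disables an augmentation without any error. Below is a minimal sketch of a config sanity check that would catch such mismatches; the key set mirrors the if-chain in train.py, and the snippet is an illustration, not part of the patch:

    import json

    # Augmentation names train.py knows how to build (mirrors its if-chain).
    KNOWN_AUGMENTATIONS = {"ColorJitter", "GaussianBlur", "GaussianNoise", "RandomErasing"}

    with open("src/imitation_learning/config.json") as f:
        cfg = json.load(f)

    # Any key outside this set would be ignored silently by train.py.
    unknown = set(cfg["augmentations"]) - KNOWN_AUGMENTATIONS
    if unknown:
        raise ValueError(f"Unrecognized augmentation keys: {sorted(unknown)}")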
diff --git a/src/imitation_learning/test_offline.py b/src/imitation_learning/test_offline.py
new file mode 100644
index 0000000..42d60d6
--- /dev/null
+++ b/src/imitation_learning/test_offline.py
@@ -0,0 +1,133 @@
+import argparse
+import csv
+import os
+from time import time
+
+import cv2
+import matplotlib.pyplot as plt
+import torch
+from torchvision.transforms import v2 as transforms
+from torchvision.models import resnet18
+from tqdm import tqdm
+
+from dl_car_control.Ackermann.utils.pilotnet import PilotNet
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--test", type=str)
+    parser.add_argument("--model", type=str)
+    parser.add_argument("--model_type", type=str)
+    parser.add_argument("--cropped", action="store_true")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    all_w_gt = []
+    all_w_pred = []
+    all_v_gt = []
+    all_v_pred = []
+    all_t = []
+
+    args = parse_args()
+    num_labels = 2
+    image_shape = (66, 200, 3)
+
+    if args.model_type == "pilotnet":
+        model = PilotNet(image_shape, num_labels)
+    elif args.model_type == "resnet18":
+        model = resnet18()
+        model.fc = torch.nn.Linear(model.fc.in_features, num_labels)
+    else:
+        raise ValueError(f"Invalid model type {args.model_type}")
+
+    model.load_state_dict(torch.load(args.model, weights_only=True, map_location=device))
+    model.to(device)
+    model.eval()
+
+    preprocess = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+        ]
+    )
+
+    with open(os.path.join(args.test, "data.csv"), "r") as f:
+        reader_csv = csv.reader(f)
+        next(reader_csv, None)  # skip the header row
+        samples = list(reader_csv)
+    num_samples = len(samples)
+
+    total_time = 0
+    min_dt = float("inf")
+    max_dt = float("-inf")
+
+    total_loss_v = 0
+    total_loss_w = 0
+    pbar = tqdm(enumerate(samples), total=num_samples, leave=True)
+    for idx, (image_name, v_gt, w_gt) in pbar:
+        start_time = time()
+
+        # cv2 loads BGR; reverse the channel order to get RGB
+        image = cv2.imread(os.path.join(args.test, image_name))[:, :, ::-1]
+        if args.cropped:
+            # keep only the lower half of the 640x480 frame
+            image = image[240:480, 0:640]
+        resized_image = cv2.resize(image, (image_shape[1], image_shape[0]))
+
+        input_tensor = preprocess(resized_image).to(device)
+        input_batch = input_tensor.unsqueeze(0)
+
+        with torch.no_grad():
+            output = model(input_batch)
+        v_pred, w_pred = output.squeeze().cpu().numpy()
+
+        all_v_gt.append(float(v_gt))
+        all_w_gt.append(float(w_gt))
+
+        all_v_pred.append(v_pred)
+        all_w_pred.append(w_pred)
+
+        total_loss_v = total_loss_v + abs(float(v_gt) - v_pred)
+        total_loss_w = total_loss_w + abs(float(w_gt) - w_pred)
+
+        all_t.append(idx)
+
+        finish_time = time()
+        dt = finish_time - start_time
+        total_time = total_time + dt
+        if dt < min_dt:
+            min_dt = dt
+        if dt > max_dt:
+            max_dt = dt
+
+    print(f"Avg. dt: {total_time / num_samples}")
+    print(f"Min. dt: {min_dt}")
+    print(f"Max. dt: {max_dt}")
+    print(f"Avg. W abs(diff): {total_loss_w / num_samples}")
+    print(f"Avg. V abs(diff): {total_loss_v / num_samples}")
+
+    plt.subplot(1, 2, 1)
+    plt.plot(all_t, all_v_gt, label="controller", color="b")
+    plt.plot(all_t, all_v_pred, label="net", color="tab:orange")
+    plt.title("Linear speed comparison")
+    plt.xlabel("Samples")
+    plt.ylabel("Linear speed output")
+    plt.legend(loc="upper left")
+    plt.subplot(1, 2, 2)
+    plt.plot(all_t, all_w_gt, label="controller", color="b")
+    plt.plot(all_t, all_w_pred, label="net", color="tab:orange")
+    plt.title("Angular speed comparison")
+    plt.xlabel("Samples")
+    plt.ylabel("Angular speed output")
+    plt.legend(loc="upper left")
+    plt.show()
+
+
+if __name__ == "__main__":
+    main()
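The script reports the mean absolute difference for V and W. A hedged sketch of an extra summary that could be computed from the same all_*_gt / all_*_pred lists follows; regression_report is a hypothetical helper, not part of the patch, and it needs numpy, which test_offline.py does not currently import:

    import numpy as np

    def regression_report(gt, pred):
        # Complements the MAE that test_offline.py already prints.
        gt = np.asarray(gt, dtype=np.float64)
        pred = np.asarray(pred, dtype=np.float64)
        err = pred - gt
        return {
            "mae": float(np.mean(np.abs(err))),
            "rmse": float(np.sqrt(np.mean(err ** 2))),  # penalizes large outliers
            "bias": float(np.mean(err)),  # systematic over/under-prediction
        }

    # e.g. at the end of main():
    # print("W:", regression_report(all_w_gt, all_w_pred))
    # print("V:", regression_report(all_v_gt, all_v_pred))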
diff --git a/src/imitation_learning/test_online.py b/src/imitation_learning/test_online.py
new file mode 100644
index 0000000..8e9828a
--- /dev/null
+++ b/src/imitation_learning/test_online.py
@@ -0,0 +1,79 @@
+import argparse
+
+import cv2
+import torch
+from torchvision.transforms import v2 as transforms
+from torchvision.models import resnet18
+
+import dl_car_control.Ackermann.utils.hal as HAL
+from dl_car_control.Ackermann.utils.pilotnet import PilotNet
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--model", type=str)
+    parser.add_argument("--model_type", type=str)
+    parser.add_argument("--cropped", action="store_true")
+
+    args = parser.parse_args()
+    return args
+
+
+args = parse_args()
+num_labels = 2
+image_shape = (66, 200, 3)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+if args.model_type == "pilotnet":
+    model = PilotNet(image_shape, num_labels)
+elif args.model_type == "resnet18":
+    model = resnet18()
+    model.fc = torch.nn.Linear(model.fc.in_features, num_labels)
+else:
+    raise ValueError(f"Invalid model type {args.model_type}")
+
+model.load_state_dict(torch.load(args.model, weights_only=True, map_location=device))
+model.to(device)
+model.eval()
+
+# Built once at import time so it is not recreated on every frame
+preprocess = transforms.Compose(
+    [
+        transforms.ToImage(),
+        transforms.ToDtype(torch.float32, scale=True),
+    ]
+)
+
+
+def user_main():
+    image = HAL.getImage()
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    height = image.shape[0]
+
+    # crop full-size frames to the lower half of the 640x480 image
+    if args.cropped and height > 100:
+        image = image[240:480, 0:640]
+
+    resized_image = cv2.resize(image, (image_shape[1], image_shape[0]))
+
+    input_tensor = preprocess(resized_image).to(device)
+    input_batch = input_tensor.unsqueeze(0)
+
+    with torch.no_grad():
+        output = model(input_batch)
+    v, w = output.squeeze().cpu().numpy()
+
+    HAL.setV(v)
+    HAL.setW(w)
+
+
+def main():
+    HAL.setW(0)
+    HAL.setV(0)
+    HAL.main(user_main)
+
+
+if __name__ == "__main__":
+    main()
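test_online.py exercises only four HAL calls: getImage, setV, setW, and main. A speculative sketch of a stand-in object with the same surface is shown below, useful for smoke-testing the preprocessing and forward pass on a machine without the simulator; FakeHAL, the 480x640 frame size, and the iteration count are assumptions inferred from how the script uses HAL, not part of the real API:

    import numpy as np

    class FakeHAL:
        """Stand-in mimicking the HAL calls test_online.py makes."""

        def getImage(self):
            # Random BGR frame at the 640x480 size the crop indices assume.
            return np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

        def setV(self, v):
            print(f"V = {float(v):.3f}")

        def setW(self, w):
            print(f"W = {float(w):.3f}")

        def main(self, callback):
            for _ in range(10):  # a few frames instead of the simulator loop
                callback()

One way to use it is to substitute the module-level HAL import (e.g. test_online.HAL = FakeHAL()) before calling main().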
diff --git a/src/imitation_learning/train.py b/src/imitation_learning/train.py
new file mode 100644
index 0000000..a6aeaeb
--- /dev/null
+++ b/src/imitation_learning/train.py
@@ -0,0 +1,260 @@
+import argparse
+from datetime import datetime
+import json
+import os
+
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models import (
+    resnet18,
+    ResNet18_Weights,
+    efficientnet_v2_s,
+    EfficientNet_V2_S_Weights,
+    vgg11,
+    VGG11_Weights,
+    mobilenet_v3_small,
+    MobileNet_V3_Small_Weights,
+    vit_b_16,
+    ViT_B_16_Weights,
+    swin_v2_s,
+    Swin_V2_S_Weights,
+)
+from torchvision.transforms import v2 as transforms
+from tqdm import tqdm
+
+from dl_car_control.Ackermann.utils.pilot_net_dataset import PilotNetDataset
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse user input arguments
+
+    :return: parsed arguments
+    :rtype: argparse.Namespace
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cfg",
+        type=str,
+        default="dl_car_control/ackermann_plus/config.json",
+        help="Path to JSON config file",
+    )
+    return parser.parse_args()
+
+
+def main():
+    """Main function"""
+    args = parse_args()
+
+    with open(args.cfg, "r") as f:
+        cfg = json.load(f)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    outdir = os.path.join(cfg["outdir"], f"{cfg['run_name']}-{timestamp}")
+    os.makedirs(outdir, exist_ok=True)
+
+    # Init model and adapt last layer
+    if cfg["backbone"] == "resnet18":
+        model = resnet18(weights=ResNet18_Weights.DEFAULT)
+        model.fc = torch.nn.Linear(model.fc.in_features, cfg["num_outputs"])
+    elif cfg["backbone"] == "efficientnet_v2_s":
+        model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)
+        model.classifier[-1] = torch.nn.Linear(
+            model.classifier[-1].in_features, cfg["num_outputs"]
+        )
+    elif cfg["backbone"] == "vgg11":
+        model = vgg11(weights=VGG11_Weights.DEFAULT)
+        model.classifier[-1] = torch.nn.Linear(
+            model.classifier[-1].in_features, cfg["num_outputs"]
+        )
+    elif cfg["backbone"] == "mobilenet_v3_small":
+        model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
+        model.classifier[-1] = torch.nn.Linear(
+            model.classifier[-1].in_features, cfg["num_outputs"]
+        )
+    elif cfg["backbone"] == "vit_b_16":
+        model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+        model.heads.head = torch.nn.Linear(
+            model.heads.head.in_features, cfg["num_outputs"]
+        )
+    elif cfg["backbone"] == "swin_v2_s":
+        model = swin_v2_s(weights=Swin_V2_S_Weights.DEFAULT)
+        model.head = torch.nn.Linear(model.head.in_features, cfg["num_outputs"])
+    else:
+        raise ValueError(f"Invalid backbone {cfg['backbone']}")
+    model.to(device)
+
+    # Init transforms. Convert to float before the augmentations because
+    # v2.GaussianNoise only accepts float tensors.
+    all_transforms = [
+        transforms.ToImage(),
+        transforms.ToDtype(torch.float32, scale=True),
+    ]
+    if "ColorJitter" in cfg["augmentations"]:
+        all_transforms.append(
+            transforms.RandomApply(
+                [
+                    transforms.ColorJitter(
+                        brightness=cfg["augmentations"]["ColorJitter"]["brightness"],
+                        contrast=cfg["augmentations"]["ColorJitter"]["contrast"],
+                        saturation=cfg["augmentations"]["ColorJitter"]["saturation"],
+                        hue=cfg["augmentations"]["ColorJitter"]["hue"],
+                    )
+                ],
+                p=cfg["augmentations"]["ColorJitter"]["p"],
+            )
+        )
+    if "GaussianBlur" in cfg["augmentations"]:
+        all_transforms.append(
+            transforms.RandomApply(
+                [
+                    transforms.GaussianBlur(
+                        kernel_size=cfg["augmentations"]["GaussianBlur"]["kernel_size"],
+                        sigma=cfg["augmentations"]["GaussianBlur"]["sigma"],
+                    )
+                ],
+                p=cfg["augmentations"]["GaussianBlur"]["p"],
+            )
+        )
+    if "GaussianNoise" in cfg["augmentations"]:
+        all_transforms.append(
+            transforms.RandomApply(
+                [
+                    transforms.GaussianNoise(
+                        mean=cfg["augmentations"]["GaussianNoise"]["mean"],
+                        sigma=cfg["augmentations"]["GaussianNoise"]["sigma"],
+                    )
+                ],
+                p=cfg["augmentations"]["GaussianNoise"]["p"],
+            )
+        )
+    if "RandomErasing" in cfg["augmentations"]:
+        # RandomErasing draws its own probability, so no RandomApply wrapper
+        all_transforms.append(
+            transforms.RandomErasing(
+                p=cfg["augmentations"]["RandomErasing"]["p"],
+                scale=cfg["augmentations"]["RandomErasing"]["scale"],
+                ratio=cfg["augmentations"]["RandomErasing"]["ratio"],
+            )
+        )
+
+    all_transforms = transforms.Compose(all_transforms)
+
+    # Init datasets and dataloaders
+    train_dataset = PilotNetDataset(
+        cfg["train_datasets"],
+        flip_images="flip" in cfg["preprocessing"],
+        transforms=all_transforms,
+        preprocessing=cfg["preprocessing"],
+    )
+
+    if "val_datasets" in cfg:
+        val_dataset = PilotNetDataset(
+            cfg["val_datasets"],
+            flip_images=False,
+            transforms=transforms.Compose(
+                [
+                    transforms.ToImage(),
+                    transforms.ToDtype(torch.float32, scale=True),
+                ]
+            ),
+            preprocessing=None,
+        )
+    elif "val_split" in cfg:
+        num_samples = len(train_dataset)
+        num_samples_val = round(num_samples * cfg["val_split"])
+        num_samples_train = num_samples - num_samples_val
+        train_dataset, val_dataset = torch.utils.data.random_split(
+            train_dataset, [num_samples_train, num_samples_val]
+        )
+    else:
+        raise ValueError("No validation dataset or split provided")
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=cfg["batch_size"],
+        num_workers=cfg["num_workers"],
+        shuffle=True,
+    )
+    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
+
+    # Init optimizer and loss function
+    if cfg["optimizer"] == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])
+    else:
+        raise ValueError(f"Invalid optimizer {cfg['optimizer']}")
+
+    if cfg["loss"] == "mse":
+        loss_fn = torch.nn.MSELoss()
+    else:
+        raise ValueError(f"Invalid loss {cfg['loss']}")
+
+    # Init tensorboard summary
+    writer = SummaryWriter(outdir)
+    writer.add_text("Configuration", json.dumps(cfg, indent=4))
+
+    # Train model
+    best_val_loss = float("inf")
+    for epoch_idx in range(cfg["max_epochs"]):
+        model.train()
+        running_loss = 0.0
+        pbar = tqdm(train_dataloader, leave=True)
+        for batch_idx, (inputs, labels) in enumerate(pbar):
+            optimizer.zero_grad()
+            outputs = model(inputs.to(device))
+            loss = loss_fn(outputs, labels.to(device))
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.cpu().item()
+            current_loss = running_loss / (batch_idx + 1)
+
+            pbar.set_description(f"TRAIN-Epoch [{epoch_idx + 1}/{cfg['max_epochs']}]")
+            pbar.set_postfix(loss=current_loss)
+
+        writer.add_scalar(f"Training loss ({cfg['loss']})", current_loss, epoch_idx)
+
+        save_model = epoch_idx % cfg["save_frequency"] == 0
+
+        if epoch_idx % cfg["val_frequency"] == 0:
+            model.eval()
+            running_val_loss = 0.0
+            pbar = tqdm(val_dataloader, leave=True)
+            for batch_idx, (inputs, labels) in enumerate(pbar):
+                with torch.no_grad():
+                    outputs = model(inputs.to(device))
+                    loss = loss_fn(outputs, labels.to(device))
+
+                running_val_loss += loss.cpu().item()
+                current_val_loss = running_val_loss / (batch_idx + 1)
+
+                pbar.set_description(
+                    f"VAL-Epoch [{epoch_idx + 1}/{cfg['max_epochs']}]"
+                )
+                pbar.set_postfix(loss=current_val_loss)
+
+            writer.add_scalar(
+                f"Validation loss ({cfg['loss']})", current_val_loss, epoch_idx
+            )
+            print(f"Validation loss: {current_val_loss}\n")
+
+            if current_val_loss < best_val_loss:
+                best_val_loss = current_val_loss
+                print(f"New best validation loss: {best_val_loss}")
+                save_model = True
+
+        if save_model:
+            fname = os.path.join(
+                outdir, f"model-epoch_{epoch_idx:03d}-loss_{current_val_loss:.3f}.pth"
+            )
+            print(f"Saving model to {fname}")
+            torch.save(model.state_dict(), fname)
+
+        writer.flush()
+
+    writer.close()
+
+
+if __name__ == "__main__":
+    main()
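Each backbone branch in train.py only swaps the final classification layer for a two-output regression head. A minimal sketch of a shape smoke test for the resnet18 branch follows; the 66x200 input size is borrowed from the offline/online test scripts, and whether PilotNetDataset actually emits that resolution is an assumption:

    import torch
    from torchvision.models import resnet18

    # Same head replacement train.py performs for the resnet18 backbone.
    model = resnet18()
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    model.eval()

    # ResNet's global average pooling makes the forward pass size-agnostic.
    dummy = torch.randn(1, 3, 66, 200)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([1, 2]) -> (v, w)

Note that the same is not true of every backbone in the if-chain: torchvision's vit_b_16 asserts a fixed 224x224 input, so selecting that backbone requires a Resize in the transform pipeline.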