diff --git a/VIT/README.md b/VIT/README.md
new file mode 100644
index 0000000..fab156a
--- /dev/null
+++ b/VIT/README.md
@@ -0,0 +1,20 @@
+
+# Engineering flexible machine learning systems by traversing functionally-invariant paths
+
+Welcome to the official GitHub repository accompanying our research paper. Our work builds flexible machine learning systems that can be applied to continual learning, neural network sparsification, and robustness against adversarial attacks.
+
+## Getting Started
+
+This repository contains the codebase and sample notebooks for running the scripts. Please refer to the notebooks for installation instructions, usage guidelines, and insights into our methodology.
+
+### Notebooks
+`FIP_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UICCGijTeugrhhRSL4PQ0GOBM_mHcC70?usp=sharing):
+This Google Colab notebook runs the continual learning script, using FIPs to learn new tasks while reducing catastrophic forgetting of previously learned tasks.
+
+`LoRA_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IfQzwpYOh6Lr6fbljIXxpfxjhP6qQRkC?usp=sharing):
+This Google Colab notebook runs the continual learning script to learn new tasks using a LoRA adapter.
+
+---
+
+For further details, please refer to our paper. If you have questions or suggestions, please open an issue or reach out to us directly.
+
diff --git a/VIT/train_fip.py b/VIT/train_fip.py
new file mode 100755
index 0000000..7720976
--- /dev/null
+++ b/VIT/train_fip.py
@@ -0,0 +1,364 @@
+# coding=utf-8
+from __future__ import absolute_import, division, print_function
+
+import logging
+import argparse
+import os
+import random
+import numpy as np
+import copy
+from datetime import timedelta
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+import torch
+import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+
+from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
+from utils.data_loader import get_loader
+from utils.dist_util import get_world_size
+
+from transformers import AutoModelForImageClassification
+
+logger = logging.getLogger(__name__)
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def simple_accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def save_model(args, model, global_step=None):
+    # Unwrap DataParallel/DistributedDataParallel before saving.
+    model_to_save = model.module if hasattr(model, 'module') else model
+    if global_step is not None:
+        model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % str(global_step))
+    else:
+        model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name)
+    model_to_save.save_pretrained(model_checkpoint)
+    logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
+
+
+def setup(args):
+
+    # Prepare model
+    if args.checkpoint:
+        model = AutoModelForImageClassification.from_pretrained(args.pretrained_dir, num_labels=args.num_classes)
+    else:
+        model = AutoModelForImageClassification.from_pretrained(args.model_type, num_labels=args.num_classes,
+                                                                ignore_mismatched_sizes=True)
+    model.to(args.device)
+    num_params = 
count_parameters(model) + + logger.info("Training parameters %s", args) + logger.info("Total Parameter: \t%2.1fM" % num_params) + print(num_params) + return args, model + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params/1000000 + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def valid(args, model, writer, test_loader, global_step, val_txt = ""): + # Validation! + eval_losses = AverageMeter() + + logger.info("\n") + logger.info("***** Running Validation *****") + logger.info(" Num steps = %d", len(test_loader)) + logger.info(" Batch size = %d", args.eval_batch_size) + + model.eval() + all_preds, all_label = [], [] + epoch_iterator = tqdm(test_loader, + desc="Validating... (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + loss_fct = torch.nn.CrossEntropyLoss() + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + with torch.no_grad(): + logits = model(x).logits + eval_loss = loss_fct(logits, y) + eval_losses.update(eval_loss.item()) + + preds = torch.argmax(logits, dim=-1) + + if len(all_preds) == 0: + all_preds.append(preds.detach().cpu().numpy()) + all_label.append(y.detach().cpu().numpy()) + else: + all_preds[0] = np.append( + all_preds[0], preds.detach().cpu().numpy(), axis=0 + ) + all_label[0] = np.append( + all_label[0], y.detach().cpu().numpy(), axis=0 + ) + epoch_iterator.set_description("Validating %s" %val_txt+ "... (loss=%2.5f)" % eval_losses.val) + + all_preds, all_label = all_preds[0], all_label[0] + accuracy = simple_accuracy(all_preds, all_label) + + logger.info("\n") + logger.info("Validation Results: %s" % val_txt) + logger.info("Global Steps: %d" % global_step) + logger.info("Valid Loss: %2.5f" % eval_losses.avg) + logger.info("Valid Accuracy: %2.5f" % accuracy) + + writer.add_scalars('test/accuracy', {val_txt: accuracy}, global_step) + return accuracy + + +def train(args, model): + """ Train the model """ + if args.local_rank in [-1, 0]: + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.output_dir+"/logs", exist_ok=True) + + writer = SummaryWriter(log_dir=os.path.join(args.output_dir+"/logs", args.name)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + # Prepare dataset + train_loader, test_loader = get_loader(args, cmin = (args.task_n-1)*10, cmax = args.task_n*10, relabel = False) + old_loaders = [get_loader(args, cmin = i*10, cmax = (i+1)*10, relabel = False) for i in range(args.task_n)] + + # Prepare optimizer and scheduler + optimizer = torch.optim.SGD(model.parameters(), + lr=args.learning_rate, + momentum=0.9, + weight_decay=args.weight_decay) + t_total = args.num_steps + if args.decay_type == "cosine": + scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + else: + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + + logger.info("\n") + model.zero_grad() + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + losses = AverageMeter() + global_step, best_acc = 0, 0 + + accuracies = [[] for _ in range(args.task_n)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + 
accuracies[i].append(accuracy) + + if args.use_fip: + fip_batch_size = (args.task_n-1)*10*50 + batch_loader, _ = get_loader(args, cmin=0, cmax=(args.task_n-1)*10, relabel = False, data_size = fip_batch_size) + old_model = copy.deepcopy(model) + f_softmax = torch.nn.Softmax(dim=1) + batch_loader_iterator = iter(batch_loader) + + # Train! + logger.info("\n") + logger.info("***** Running training *****") + logger.info(" Total optimization steps = %d", args.num_steps) + logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * ( + torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + while True: + model.train() + epoch_iterator = tqdm(train_loader, + desc="Training (X / X Steps) (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + + + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + output = model(x).logits + loss_fct = torch.nn.CrossEntropyLoss() + CEloss = loss_fct(output, y) + + if args.use_fip: + try: + x_old, y_old = next(batch_loader_iterator) + except StopIteration: + batch_loader_iterator = iter(batch_loader) + x_old, y_old = next(batch_loader_iterator) + x_old = x_old.to(args.device) + + output_ori = f_softmax(old_model(x).logits) + output_d2 = f_softmax(model(x_old).logits) + output_newDest = f_softmax(old_model(x_old).logits) + output_ori = output_ori.detach() + output_newDest = output_newDest.detach() + epsAdd = max(1e-10, torch.min(output_ori)*1e-3) + loss1 = torch.sum(-torch.log(torch.sum(torch.sqrt(output*output_ori+epsAdd), axis=1))) + loss2 = torch.sum(-torch.log(torch.sum(torch.sqrt(output_d2*output_newDest+epsAdd), axis=1))) + else: + loss2 = 0 + loss = CEloss + loss2 + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + losses.update(loss.item()*args.gradient_accumulation_steps) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scheduler.step() + optimizer.step() + optimizer.zero_grad() + global_step += 1 + + epoch_iterator.set_description( + "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) + ) + if args.local_rank in [-1, 0]: + writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) + writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) + if global_step % args.eval_every == 0 and args.local_rank in [-1, 0]: + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + if best_acc < accuracy: + save_model(args, model) + best_acc = accuracy + model.train() + + if global_step % t_total == 0: + break + losses.reset() + if global_step % t_total == 0: + break + + if args.local_rank in [-1, 0]: + writer.close() + logger.info("Best Accuracy: \t%f" % best_acc) + logger.info("End Training!") + + + timesteps = [i for i in range(0, args.num_steps+1, args.eval_every)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + plt.plot(timesteps, accuracies[i], label=txt) + + plt.title("Accuracy plot") + 
plt.xlabel("Epochs") + plt.ylabel("Accuracy") + plt.legend() + plt.savefig(os.path.join(args.output_dir, "accuracy_plot.png")) + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--name", required=True, + help="Name of this run. Used for monitoring.") + parser.add_argument("--num_classes", required=True, type=int, + help="Number of output classes") + parser.add_argument("--task_n", required=True, type=int, + help="Number of the task currently training") + parser.add_argument("--model_type", type=str, default="google/vit-base-patch16-224", + help="Which variant to use.") + parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz", + help="Where to search for pretrained ViT models.") + parser.add_argument("--output_dir", default="output", type=str, + help="The output directory where checkpoints will be written.") + parser.add_argument("--img_size", default=224, type=int, + help="Resolution size") + parser.add_argument("--train_batch_size", default=512, type=int, + help="Total batch size for training.") + parser.add_argument("--fip_batch_size", default=512, type=int, + help="Total batch size for fip training.") + parser.add_argument("--eval_batch_size", default=64, type=int, + help="Total batch size for eval.") + parser.add_argument("--eval_every", default=40, type=int, + help="Run prediction on validation set every so many steps." + "Will always run one evaluation at the end of training.") + + parser.add_argument("--learning_rate", default=3e-2, type=float, + help="The initial learning rate for SGD.") + parser.add_argument("--weight_decay", default=0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--num_steps", default=240, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", + help="How to decay the learning rate.") + parser.add_argument("--warmup_steps", default=120, type=int, + help="Step of training to perform learning rate warmup for.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + + parser.add_argument("--local_rank", type=int, default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument('--checkpoint', action='store_true', + help="Whether to use custom pre-trained model or imagenet21k pretrained") + parser.add_argument('--use_fip', action='store_true', + help="Whether to use fip training or normal training") + args = parser.parse_args() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl', + timeout=timedelta(minutes=60)) + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: 
%s, device: %s, n_gpu: %s, distributed training: %s" % + (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1))) + + # Set seed + set_seed(args) + + # Model & Tokenizer Setup + args, model = setup(args) + + # Training + train(args, model) + + +if __name__ == "__main__": + main() diff --git a/VIT/train_lora.py b/VIT/train_lora.py new file mode 100644 index 0000000..52db294 --- /dev/null +++ b/VIT/train_lora.py @@ -0,0 +1,358 @@ +# coding=utf-8 +from __future__ import absolute_import, division, print_function + +import logging +import argparse +import os +import random +import numpy as np +import copy +from datetime import timedelta +import matplotlib.pyplot as plt +from tqdm import tqdm + +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter + +from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule +from utils.data_loader import get_loader +from utils.dist_util import get_world_size + +from transformers import AutoModelForImageClassification +from peft import LoraConfig, get_peft_model + +logger = logging.getLogger(__name__) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def save_model(args, model, global_step=None): + model_to_save = model.module if hasattr(model, 'module') else model + if global_step is not None: + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % str(global_step)) + else: + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name) + model.save_pretrained(model_checkpoint) + logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir) + + +def setup(args): + # Prepare model + if args.checkpoint: + model = AutoModelForImageClassification.from_pretrained(args.pretrained_dir, num_labels=args.num_classes) + else: + model = AutoModelForImageClassification.from_pretrained(args.model_type, num_labels=args.num_classes, + ignore_mismatched_sizes=True) + + model.to(args.device) + num_params = count_parameters(model) + + # logger.info("{}".format(config)) + logger.info("Training parameters %s", args) + logger.info("Total Parameter: \t%2.1fM" % num_params) + print(num_params) + + return args, model + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params/1000000 + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def valid(args, model, writer, test_loader, global_step, val_txt = ""): + # Validation! + eval_losses = AverageMeter() + + logger.info("\n") + logger.info("***** Running Validation *****") + logger.info(" Num steps = %d", len(test_loader)) + logger.info(" Batch size = %d", args.eval_batch_size) + + model.eval() + all_preds, all_label = [], [] + epoch_iterator = tqdm(test_loader, + desc="Validating... 
(loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + loss_fct = torch.nn.CrossEntropyLoss() + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + with torch.no_grad(): + logits = model(x).logits + eval_loss = loss_fct(logits, y) + eval_losses.update(eval_loss.item()) + + preds = torch.argmax(logits, dim=-1) + + if len(all_preds) == 0: + all_preds.append(preds.detach().cpu().numpy()) + all_label.append(y.detach().cpu().numpy()) + else: + all_preds[0] = np.append( + all_preds[0], preds.detach().cpu().numpy(), axis=0 + ) + all_label[0] = np.append( + all_label[0], y.detach().cpu().numpy(), axis=0 + ) + epoch_iterator.set_description("Validating %s" %val_txt+ "... (loss=%2.5f)" % eval_losses.val) + + all_preds, all_label = all_preds[0], all_label[0] + accuracy = simple_accuracy(all_preds, all_label) + + logger.info("\n") + logger.info("Validation Results: %s" % val_txt) + logger.info("Global Steps: %d" % global_step) + logger.info("Valid Loss: %2.5f" % eval_losses.avg) + logger.info("Valid Accuracy: %2.5f" % accuracy) + + writer.add_scalars('test/accuracy', {val_txt: accuracy}, global_step) + return accuracy + + +def train(args, model): + """ Train the model """ + if args.local_rank in [-1, 0]: + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.output_dir+"/logs", exist_ok=True) + + writer = SummaryWriter(log_dir=os.path.join(args.output_dir+"/logs", args.name)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + # Prepare dataset + train_loader, test_loader = get_loader(args, cmin = (args.task_n-1)*10, cmax = args.task_n*10, relabel = False) + old_loaders = [get_loader(args, cmin = i*10, cmax = (i+1)*10, relabel = False) for i in range(args.task_n)] + if args.use_lora: + config = LoraConfig( + r=args.lora_rank, + lora_alpha=args.local_rank*2, + # target_modules=["query", "value"], + target_modules=["query", "value", "classifier"], + lora_dropout=0.1, + bias="none", + # modules_to_save=["classifier"], + ) + lora_model = get_peft_model(model, config) + model.to(args.device) + num_params_lora = count_parameters(model) + logger.info("\n") + logger.info("Total Parameter (LoRA model): \t%2.1fM" % num_params_lora) + print(num_params_lora) + + # Prepare optimizer and scheduler + optimizer = torch.optim.SGD(model.parameters(), + lr=args.learning_rate, + momentum=0.9, + weight_decay=args.weight_decay) + t_total = args.num_steps + if args.decay_type == "cosine": + scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + else: + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + + model.zero_grad() + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + losses = AverageMeter() + global_step, best_acc = 0, 0 + + accuracies = [[] for _ in range(args.task_n)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + # Train! + logger.info("\n") + logger.info("***** Running training *****") + logger.info(" Total optimization steps = %d", args.num_steps) + logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) + logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * ( + torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + while True: + model.train() + epoch_iterator = tqdm(train_loader, + desc="Training (X / X Steps) (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + + + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + output = model(x).logits + loss_fct = torch.nn.CrossEntropyLoss() + CEloss = loss_fct(output, y) + loss = CEloss + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + losses.update(loss.item()*args.gradient_accumulation_steps) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scheduler.step() + optimizer.step() + optimizer.zero_grad() + global_step += 1 + + epoch_iterator.set_description( + "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) + ) + if args.local_rank in [-1, 0]: + writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) + writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) + if global_step % args.eval_every == 0 and args.local_rank in [-1, 0]: + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + if best_acc < accuracy: + save_model(args, model) + best_acc = accuracy + model.train() + + if global_step % t_total == 0: + break + losses.reset() + if global_step % t_total == 0: + break + + if args.local_rank in [-1, 0]: + writer.close() + logger.info("Best Accuracy: \t%f" % best_acc) + logger.info("End Training!") + + + timesteps = [i for i in range(0, args.num_steps+1, args.eval_every)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + plt.plot(timesteps, accuracies[i], label=txt) + + plt.title("Accuracy plot") + plt.xlabel("Epochs") + plt.ylabel("Accuracy") + plt.legend() + plt.savefig(os.path.join(args.output_dir, "accuracy_plot.png")) + + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--name", required=True, + help="Name of this run. 
Used for monitoring.") + parser.add_argument("--num_classes", required=True, type=int, + help="Number of output classes") + parser.add_argument("--task_n", required=True, type=int, + help="Number of the task currently training") + parser.add_argument("--model_type", + default="ViT-B_16", + help="Which variant to use.") + parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz", + help="Where to search for pretrained ViT models.") + parser.add_argument("--output_dir", default="output", type=str, + help="The output directory where checkpoints will be written.") + parser.add_argument("--img_size", default=224, type=int, + help="Resolution size") + parser.add_argument("--train_batch_size", default=512, type=int, + help="Total batch size for training.") + parser.add_argument("--fip_batch_size", default=512, type=int, + help="Total batch size for fip training.") + parser.add_argument("--eval_batch_size", default=64, type=int, + help="Total batch size for eval.") + parser.add_argument("--eval_every", default=40, type=int, + help="Run prediction on validation set every so many steps." + "Will always run one evaluation at the end of training.") + + parser.add_argument("--learning_rate", default=3e-2, type=float, + help="The initial learning rate for SGD.") + parser.add_argument("--weight_decay", default=0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--num_steps", default=240, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", + help="How to decay the learning rate.") + parser.add_argument("--warmup_steps", default=120, type=int, + help="Step of training to perform learning rate warmup for.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + parser.add_argument("--local_rank", type=int, default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument('--checkpoint', action='store_true', + help="Whether to use custom pre-trained model or imagenet21k pretrained") + parser.add_argument('--use_lora', action='store_true', + help="Whether to use lora training or normal training") + parser.add_argument("--lora_rank", default=4, type=int, + help="Rank of LoRA adapter.") + args = parser.parse_args() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl', + timeout=timedelta(minutes=60)) + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" % + (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1))) + + # Set seed + set_seed(args) + + # Model & Tokenizer Setup + args, model = 
setup(args) + + # Training + train(args, model) + + +if __name__ == "__main__": + main() diff --git a/VIT/utils/data_loader.py b/VIT/utils/data_loader.py new file mode 100644 index 0000000..e0447d2 --- /dev/null +++ b/VIT/utils/data_loader.py @@ -0,0 +1,110 @@ +import logging +import torch +from torchvision import transforms, datasets +from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler, SequentialSampler +import pickle +from PIL import Image +import numpy as np +import os +import matplotlib.pyplot as plt + +logger = logging.getLogger(__name__) + +class CIFARDataset(Dataset): + + def __init__(self, data_dir, train = True, transform = None): + + self.data_file = data_dir +"/train" if train else data_dir+"/test" + self.meta_file = data_dir +"/meta" + self.transform = transform + + with open(self.data_file, "rb") as f: + data = pickle.load(f, encoding="latin1") + with open(self.meta_file, "rb") as f: + meta = pickle.load(f, encoding="latin1") + + self.images = data["data"] + self.images = np.vstack(self.images).reshape(-1, 3, 32, 32) + self.images = self.images.transpose((0, 2, 3, 1)) # convert to HWC + self.labels = np.array(data["fine_labels"]) + self.names = np.array(meta["fine_label_names"]) + + # self._filter_classes() + # self._visualize_image() + + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img = Image.fromarray(self.images[idx]) + label = torch.tensor(self.labels[idx]).long() + + if self.transform is not None: + img = self.transform(img) + # label = label.view(1) + return img, label + + def _filter_classes(self, min=0, max=10, relabel=True): + filtered = (self.labels >= min) & (self.labels < max) + self.images = self.images[filtered] + self.labels = self.labels[filtered] + if relabel: + self.labels -= min + self.names = self.names[min:max] + # print(self.names) + + def _filter_size(self, data_size): + self.images = self.images[:data_size] + self.labels = self.labels[:data_size] + + def _visualize_image(self): + idx = np.random.randint(len(self), size=1)[0] + img, label = self[idx] + class_name = self.names[int(label.item())] + print(idx, class_name) + plt.imshow(img) + plt.show() + +def get_loader(args, cmin = 0, cmax = 10, relabel = True, data_size = None): + + data_dir = "data/cifar-100-python" + + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + transform_train = transforms.Compose([ + transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + transform_test = transforms.Compose([ + transforms.Resize((args.img_size, args.img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + trainset = CIFARDataset(data_dir, train = True, transform = transform_train) + trainset._filter_classes(cmin, cmax, relabel) + if data_size is not None: + trainset._filter_size(data_size) + testset = CIFARDataset(data_dir, train = False, transform = transform_test) + testset._filter_classes(cmin, cmax, relabel) + + if args.local_rank == 0: + torch.distributed.barrier() + + train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) + test_sampler = SequentialSampler(testset) + train_loader = DataLoader(trainset, + sampler=train_sampler, + batch_size=args.train_batch_size, + num_workers=4, + pin_memory=True) + test_loader = DataLoader(testset, + sampler=test_sampler, + batch_size=args.eval_batch_size, + 
num_workers=4, + pin_memory=True) if testset is not None else None + + return train_loader, test_loader diff --git a/VIT/utils/data_utils.py b/VIT/utils/data_utils.py new file mode 100755 index 0000000..dc653f5 --- /dev/null +++ b/VIT/utils/data_utils.py @@ -0,0 +1,62 @@ +import logging + +import torch + +from torchvision import transforms, datasets +from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler + + +logger = logging.getLogger(__name__) + + +def get_loader(args): + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + transform_train = transforms.Compose([ + transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + transform_test = transforms.Compose([ + transforms.Resize((args.img_size, args.img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + if args.dataset == "cifar10": + trainset = datasets.CIFAR10(root="./data", + train=True, + download=True, + transform=transform_train) + testset = datasets.CIFAR10(root="./data", + train=False, + download=True, + transform=transform_test) if args.local_rank in [-1, 0] else None + + else: + trainset = datasets.CIFAR100(root="./data", + train=True, + download=True, + transform=transform_train) + testset = datasets.CIFAR100(root="./data", + train=False, + download=True, + transform=transform_test) if args.local_rank in [-1, 0] else None + if args.local_rank == 0: + torch.distributed.barrier() + + train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) + test_sampler = SequentialSampler(testset) + train_loader = DataLoader(trainset, + sampler=train_sampler, + batch_size=args.train_batch_size, + num_workers=4, + pin_memory=True) + test_loader = DataLoader(testset, + sampler=test_sampler, + batch_size=args.eval_batch_size, + num_workers=4, + pin_memory=True) if testset is not None else None + + return train_loader, test_loader diff --git a/VIT/utils/dist_util.py b/VIT/utils/dist_util.py new file mode 100755 index 0000000..ab8c447 --- /dev/null +++ b/VIT/utils/dist_util.py @@ -0,0 +1,30 @@ +import torch.distributed as dist + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + +def is_main_process(): + return get_rank() == 0 + +def format_step(step): + if isinstance(step, str): + return step + s = "" + if len(step) > 0: + s += "Training Epoch: {} ".format(step[0]) + if len(step) > 1: + s += "Training Iteration: {} ".format(step[1]) + if len(step) > 2: + s += "Validation Iteration: {} ".format(step[2]) + return s diff --git a/VIT/utils/scheduler.py b/VIT/utils/scheduler.py new file mode 100755 index 0000000..9daaf6e --- /dev/null +++ b/VIT/utils/scheduler.py @@ -0,0 +1,63 @@ +import logging +import math + +from torch.optim.lr_scheduler import LambdaLR + +logger = logging.getLogger(__name__) + +class ConstantLRSchedule(LambdaLR): + """ Constant learning rate schedule. + """ + def __init__(self, optimizer, last_epoch=-1): + super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) + + +class WarmupConstantSchedule(LambdaLR): + """ Linear warmup and then constant. 
+    Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
+    Keeps learning rate schedule equal to 1. after warmup_steps.
+    """
+    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+        self.warmup_steps = warmup_steps
+        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1.0, self.warmup_steps))
+        return 1.
+
+
+class WarmupLinearSchedule(LambdaLR):
+    """ Linear warmup and then linear decay.
+        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+        Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
+    """
+    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1, self.warmup_steps))
+        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
+
+
+class WarmupCosineSchedule(LambdaLR):
+    """ Linear warmup and then cosine decay.
+        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+        Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
+        If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
+    """
+    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        self.cycles = cycles
+        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1.0, self.warmup_steps))
+        # progress after warmup
+        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
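
As a quick orientation to the schedulers above, here is a minimal, self-contained sketch of how `WarmupCosineSchedule` is typically driven together with SGD, with one `scheduler.step()` per optimization step. It is not part of the repository: the tiny linear model and the dummy loss are illustrative stand-ins, and the values simply mirror the script defaults (`--learning_rate 3e-2`, `--warmup_steps 120`, `--num_steps 240`); it assumes it is run from the `VIT/` directory so `utils.scheduler` is importable.

import torch
from utils.scheduler import WarmupCosineSchedule

model = torch.nn.Linear(10, 2)  # illustrative stand-in for the ViT classifier
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, momentum=0.9)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=120, t_total=240)

for step in range(240):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()  # dummy forward/backward pass
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR rises linearly for 120 steps, then follows a cosine decay toward 0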