From 007ea4217dad0b9175df833f68a0fb4855151b0e Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Sat, 8 Jun 2024 00:13:02 +0300 Subject: [PATCH 1/7] FIP VIT changes --- VIT/train_fip.py | 363 +++++++++++++++++++++++++++++++++++++++ VIT/utils/data_loader.py | 110 ++++++++++++ VIT/utils/data_utils.py | 62 +++++++ VIT/utils/dist_util.py | 30 ++++ VIT/utils/scheduler.py | 63 +++++++ 5 files changed, 628 insertions(+) create mode 100755 VIT/train_fip.py create mode 100644 VIT/utils/data_loader.py create mode 100755 VIT/utils/data_utils.py create mode 100755 VIT/utils/dist_util.py create mode 100755 VIT/utils/scheduler.py diff --git a/VIT/train_fip.py b/VIT/train_fip.py new file mode 100755 index 0000000..e5b3f36 --- /dev/null +++ b/VIT/train_fip.py @@ -0,0 +1,363 @@ +# coding=utf-8 +from __future__ import absolute_import, division, print_function + +import logging +import argparse +import os +import random +import numpy as np +import copy +from datetime import timedelta +import matplotlib.pyplot as plt +from tqdm import tqdm + +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter + +from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule +from utils.data_loader import get_loader +from utils.dist_util import get_world_size + +from transformers import AutoModelForImageClassification + +logger = logging.getLogger(__name__) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def save_model(args, model, global_step=None): + model_to_save = model.module if hasattr(model, 'module') else model + if global_step is not None: + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % str(global_step)) + else: + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name) + model.save_pretrained(model_checkpoint) + logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir) + + +def setup(args): + + # Prepare model + if args.checkpoint: + model = AutoModelForImageClassification.from_pretrained(args.pretrained_dir, num_labels=args.num_classes) + else: + model = AutoModelForImageClassification.from_pretrained(args.model_type, num_labels=args.num_classes, + ignore_mismatched_sizes=True) + model.to(args.device) + num_params = count_parameters(model) + + logger.info("Training parameters %s", args) + logger.info("Total Parameter: \t%2.1fM" % num_params) + print(num_params) + return args, model + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params/1000000 + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def valid(args, model, writer, test_loader, global_step, val_txt = ""): + # Validation! + eval_losses = AverageMeter() + + logger.info("\n") + logger.info("***** Running Validation *****") + logger.info(" Num steps = %d", len(test_loader)) + logger.info(" Batch size = %d", args.eval_batch_size) + + model.eval() + all_preds, all_label = [], [] + epoch_iterator = tqdm(test_loader, + desc="Validating... 
(loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + loss_fct = torch.nn.CrossEntropyLoss() + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + with torch.no_grad(): + logits = model(x).logits + eval_loss = loss_fct(logits, y) + eval_losses.update(eval_loss.item()) + + preds = torch.argmax(logits, dim=-1) + + if len(all_preds) == 0: + all_preds.append(preds.detach().cpu().numpy()) + all_label.append(y.detach().cpu().numpy()) + else: + all_preds[0] = np.append( + all_preds[0], preds.detach().cpu().numpy(), axis=0 + ) + all_label[0] = np.append( + all_label[0], y.detach().cpu().numpy(), axis=0 + ) + epoch_iterator.set_description("Validating %s" %val_txt+ "... (loss=%2.5f)" % eval_losses.val) + + all_preds, all_label = all_preds[0], all_label[0] + accuracy = simple_accuracy(all_preds, all_label) + + logger.info("\n") + logger.info("Validation Results: %s" % val_txt) + logger.info("Global Steps: %d" % global_step) + logger.info("Valid Loss: %2.5f" % eval_losses.avg) + logger.info("Valid Accuracy: %2.5f" % accuracy) + + writer.add_scalars('test/accuracy', {val_txt: accuracy}, global_step) + return accuracy + + +def train(args, model): + """ Train the model """ + if args.local_rank in [-1, 0]: + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.output_dir+"/logs", exist_ok=True) + + writer = SummaryWriter(log_dir=os.path.join(args.output_dir+"/logs", args.name)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + # Prepare dataset + train_loader, test_loader = get_loader(args, cmin = (args.task_n-1)*10, cmax = args.task_n*10, relabel = False) + old_loaders = [get_loader(args, cmin = i*10, cmax = (i+1)*10, relabel = False) for i in range(args.task_n)] + + # Prepare optimizer and scheduler + optimizer = torch.optim.SGD(model.parameters(), + lr=args.learning_rate, + momentum=0.9, + weight_decay=args.weight_decay) + t_total = args.num_steps + if args.decay_type == "cosine": + scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + else: + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + + logger.info("\n") + model.zero_grad() + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + losses = AverageMeter() + global_step, best_acc = 0, 0 + + accuracies = [[] for _ in range(args.task_n)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + if args.use_fip: + fip_batch_size = (args.task_n-1)*10*50 + batch_loader, _ = get_loader(args, cmin=0, cmax=(args.task_n-1)*10, relabel = False, data_size = fip_batch_size) + old_model = copy.deepcopy(model) + f_softmax = torch.nn.Softmax(dim=1) + batch_loader_iterator = iter(batch_loader) + + # Train! + logger.info("\n") + logger.info("***** Running training *****") + logger.info(" Total optimization steps = %d", args.num_steps) + logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) + logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * ( + torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + while True: + model.train() + epoch_iterator = tqdm(train_loader, + desc="Training (X / X Steps) (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + + + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + output = model(x).logits + loss_fct = torch.nn.CrossEntropyLoss() + CEloss = loss_fct(output, y) + + if args.use_fip: + try: + x_old, y_old = next(batch_loader_iterator) + except StopIteration: + batch_loader_iterator = iter(batch_loader) + x_old, y_old = next(batch_loader_iterator) + x_old = x_old.to(args.device) + + output_ori = f_softmax(old_model(x).logits) + output_d2 = f_softmax(model(x_old).logits) + output_newDest = f_softmax(old_model(x_old).logits) + output_ori = output_ori.detach() + output_newDest = output_newDest.detach() + epsAdd = max(1e-10, torch.min(output_ori)*1e-3) + loss1 = torch.sum(-torch.log(torch.sum(torch.sqrt(output*output_ori+epsAdd), axis=1))) + loss2 = torch.sum(-torch.log(torch.sum(torch.sqrt(output_d2*output_newDest+epsAdd), axis=1))) + else: + loss2 = 0 + loss = CEloss + loss2 + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + losses.update(loss.item()*args.gradient_accumulation_steps) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scheduler.step() + optimizer.step() + optimizer.zero_grad() + global_step += 1 + + epoch_iterator.set_description( + "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) + ) + if args.local_rank in [-1, 0]: + writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) + writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) + if global_step % args.eval_every == 0 and args.local_rank in [-1, 0]: + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + if best_acc < accuracy: + save_model(args, model) + best_acc = accuracy + model.train() + + if global_step % t_total == 0: + break + losses.reset() + if global_step % t_total == 0: + break + + if args.local_rank in [-1, 0]: + writer.close() + logger.info("Best Accuracy: \t%f" % best_acc) + logger.info("End Training!") + + + timesteps = [i for i in range(0, args.num_steps+1, args.eval_every)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + plt.plot(timesteps, accuracies[i], label=txt) + + plt.title("Accuracy plot") + plt.xlabel("Epochs") + plt.ylabel("Accuracy") + plt.savefig(os.path.join(args.output_dir, "accuracy_plot.png")) + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--name", required=True, + help="Name of this run. 
Used for monitoring.") + parser.add_argument("--num_classes", required=True, type=int, + help="Number of output classes") + parser.add_argument("--task_n", required=True, type=int, + help="Number of the task currently training") + parser.add_argument("--model_type", type=str, default="google/vit-base-patch16-224", + help="Which variant to use.") + parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz", + help="Where to search for pretrained ViT models.") + parser.add_argument("--output_dir", default="output", type=str, + help="The output directory where checkpoints will be written.") + parser.add_argument("--img_size", default=224, type=int, + help="Resolution size") + parser.add_argument("--train_batch_size", default=512, type=int, + help="Total batch size for training.") + parser.add_argument("--fip_batch_size", default=512, type=int, + help="Total batch size for fip training.") + parser.add_argument("--eval_batch_size", default=64, type=int, + help="Total batch size for eval.") + parser.add_argument("--eval_every", default=40, type=int, + help="Run prediction on validation set every so many steps." + "Will always run one evaluation at the end of training.") + + parser.add_argument("--learning_rate", default=3e-2, type=float, + help="The initial learning rate for SGD.") + parser.add_argument("--weight_decay", default=0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--num_steps", default=240, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", + help="How to decay the learning rate.") + parser.add_argument("--warmup_steps", default=120, type=int, + help="Step of training to perform learning rate warmup for.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + + parser.add_argument("--local_rank", type=int, default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument('--checkpoint', action='store_true', + help="Whether to use custom pre-trained model or imagenet21k pretrained") + parser.add_argument('--use_fip', action='store_true', + help="Whether to use fip training or normal training") + args = parser.parse_args() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl', + timeout=timedelta(minutes=60)) + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" % + (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1))) + + # Set seed + set_seed(args) + + # Model & Tokenizer Setup + args, model = setup(args) + + # Training + train(args, model) + + +if __name__ == 
"__main__": + main() diff --git a/VIT/utils/data_loader.py b/VIT/utils/data_loader.py new file mode 100644 index 0000000..e0447d2 --- /dev/null +++ b/VIT/utils/data_loader.py @@ -0,0 +1,110 @@ +import logging +import torch +from torchvision import transforms, datasets +from torch.utils.data import Dataset, DataLoader, RandomSampler, DistributedSampler, SequentialSampler +import pickle +from PIL import Image +import numpy as np +import os +import matplotlib.pyplot as plt + +logger = logging.getLogger(__name__) + +class CIFARDataset(Dataset): + + def __init__(self, data_dir, train = True, transform = None): + + self.data_file = data_dir +"/train" if train else data_dir+"/test" + self.meta_file = data_dir +"/meta" + self.transform = transform + + with open(self.data_file, "rb") as f: + data = pickle.load(f, encoding="latin1") + with open(self.meta_file, "rb") as f: + meta = pickle.load(f, encoding="latin1") + + self.images = data["data"] + self.images = np.vstack(self.images).reshape(-1, 3, 32, 32) + self.images = self.images.transpose((0, 2, 3, 1)) # convert to HWC + self.labels = np.array(data["fine_labels"]) + self.names = np.array(meta["fine_label_names"]) + + # self._filter_classes() + # self._visualize_image() + + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img = Image.fromarray(self.images[idx]) + label = torch.tensor(self.labels[idx]).long() + + if self.transform is not None: + img = self.transform(img) + # label = label.view(1) + return img, label + + def _filter_classes(self, min=0, max=10, relabel=True): + filtered = (self.labels >= min) & (self.labels < max) + self.images = self.images[filtered] + self.labels = self.labels[filtered] + if relabel: + self.labels -= min + self.names = self.names[min:max] + # print(self.names) + + def _filter_size(self, data_size): + self.images = self.images[:data_size] + self.labels = self.labels[:data_size] + + def _visualize_image(self): + idx = np.random.randint(len(self), size=1)[0] + img, label = self[idx] + class_name = self.names[int(label.item())] + print(idx, class_name) + plt.imshow(img) + plt.show() + +def get_loader(args, cmin = 0, cmax = 10, relabel = True, data_size = None): + + data_dir = "data/cifar-100-python" + + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + transform_train = transforms.Compose([ + transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + transform_test = transforms.Compose([ + transforms.Resize((args.img_size, args.img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + trainset = CIFARDataset(data_dir, train = True, transform = transform_train) + trainset._filter_classes(cmin, cmax, relabel) + if data_size is not None: + trainset._filter_size(data_size) + testset = CIFARDataset(data_dir, train = False, transform = transform_test) + testset._filter_classes(cmin, cmax, relabel) + + if args.local_rank == 0: + torch.distributed.barrier() + + train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) + test_sampler = SequentialSampler(testset) + train_loader = DataLoader(trainset, + sampler=train_sampler, + batch_size=args.train_batch_size, + num_workers=4, + pin_memory=True) + test_loader = DataLoader(testset, + sampler=test_sampler, + batch_size=args.eval_batch_size, + num_workers=4, + pin_memory=True) if testset is not None else None + + 
return train_loader, test_loader diff --git a/VIT/utils/data_utils.py b/VIT/utils/data_utils.py new file mode 100755 index 0000000..dc653f5 --- /dev/null +++ b/VIT/utils/data_utils.py @@ -0,0 +1,62 @@ +import logging + +import torch + +from torchvision import transforms, datasets +from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler + + +logger = logging.getLogger(__name__) + + +def get_loader(args): + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + + transform_train = transforms.Compose([ + transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + transform_test = transforms.Compose([ + transforms.Resize((args.img_size, args.img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + if args.dataset == "cifar10": + trainset = datasets.CIFAR10(root="./data", + train=True, + download=True, + transform=transform_train) + testset = datasets.CIFAR10(root="./data", + train=False, + download=True, + transform=transform_test) if args.local_rank in [-1, 0] else None + + else: + trainset = datasets.CIFAR100(root="./data", + train=True, + download=True, + transform=transform_train) + testset = datasets.CIFAR100(root="./data", + train=False, + download=True, + transform=transform_test) if args.local_rank in [-1, 0] else None + if args.local_rank == 0: + torch.distributed.barrier() + + train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) + test_sampler = SequentialSampler(testset) + train_loader = DataLoader(trainset, + sampler=train_sampler, + batch_size=args.train_batch_size, + num_workers=4, + pin_memory=True) + test_loader = DataLoader(testset, + sampler=test_sampler, + batch_size=args.eval_batch_size, + num_workers=4, + pin_memory=True) if testset is not None else None + + return train_loader, test_loader diff --git a/VIT/utils/dist_util.py b/VIT/utils/dist_util.py new file mode 100755 index 0000000..ab8c447 --- /dev/null +++ b/VIT/utils/dist_util.py @@ -0,0 +1,30 @@ +import torch.distributed as dist + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + +def is_main_process(): + return get_rank() == 0 + +def format_step(step): + if isinstance(step, str): + return step + s = "" + if len(step) > 0: + s += "Training Epoch: {} ".format(step[0]) + if len(step) > 1: + s += "Training Iteration: {} ".format(step[1]) + if len(step) > 2: + s += "Validation Iteration: {} ".format(step[2]) + return s diff --git a/VIT/utils/scheduler.py b/VIT/utils/scheduler.py new file mode 100755 index 0000000..9daaf6e --- /dev/null +++ b/VIT/utils/scheduler.py @@ -0,0 +1,63 @@ +import logging +import math + +from torch.optim.lr_scheduler import LambdaLR + +logger = logging.getLogger(__name__) + +class ConstantLRSchedule(LambdaLR): + """ Constant learning rate schedule. + """ + def __init__(self, optimizer, last_epoch=-1): + super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) + + +class WarmupConstantSchedule(LambdaLR): + """ Linear warmup and then constant. + Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. + Keeps learning rate schedule equal to 1. 
after warmup_steps. + """ + def __init__(self, optimizer, warmup_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1.0, self.warmup_steps)) + return 1. + + +class WarmupLinearSchedule(LambdaLR): + """ Linear warmup and then linear decay. + Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. + Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. + """ + def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): + self.warmup_steps = warmup_steps + self.t_total = t_total + super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1, self.warmup_steps)) + return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) + + +class WarmupCosineSchedule(LambdaLR): + """ Linear warmup and then cosine decay. + Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. + Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. + If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. + """ + def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): + self.warmup_steps = warmup_steps + self.t_total = t_total + self.cycles = cycles + super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1.0, self.warmup_steps)) + # progress after warmup + progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) + return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) From 8e43915add382c20bacd55dfaa2c053365681a3c Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Sat, 8 Jun 2024 03:26:08 +0300 Subject: [PATCH 2/7] save accuracy plot --- VIT/train_fip.py | 1 + 1 file changed, 1 insertion(+) diff --git a/VIT/train_fip.py b/VIT/train_fip.py index e5b3f36..7720976 100755 --- a/VIT/train_fip.py +++ b/VIT/train_fip.py @@ -276,6 +276,7 @@ def train(args, model): plt.title("Accuracy plot") plt.xlabel("Epochs") plt.ylabel("Accuracy") + plt.legend() plt.savefig(os.path.join(args.output_dir, "accuracy_plot.png")) def main(): From ae2ee39bc79a455c32c0bd796bb29f6c4b5b4220 Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Sun, 9 Jun 2024 01:30:29 +0300 Subject: [PATCH 3/7] add readme --- VIT/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 VIT/README.md diff --git a/VIT/README.md b/VIT/README.md new file mode 100644 index 0000000..293e578 --- /dev/null +++ b/VIT/README.md @@ -0,0 +1,17 @@ + +# Engineering flexible machine learning systems by traversing functionally-invariant paths + +Welcome to the official GitHub repository accompanying our latest research paper. Our work dives deep into building flexible machine learning systems that can be used in continual learning, sparsifying neural networks, or reducing adversarial attacks. + +## Getting Started + +This repository contains the codebase and sample notebooks to run the scripts. 
Please refer to the notebooks for installation instructions, usage guidelines, and insights into our methodology. + +### Notebooks +`FIP_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UICCGijTeugrhhRSL4PQ0GOBM_mHcC70?usp=sharing): +In this Google Colab notebook, we will be running the continual learning script to learn new tasks with the reduction of catastrophic forgetting of the previous learnt tasks. + +--- + +For further details, kindly refer to our paper. If you have queries or suggestions, please open an issue or reach out to us directly. + From 2b0f24aaece015c9fc3e726bb24749434ca8d689 Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Tue, 11 Jun 2024 00:19:09 +0300 Subject: [PATCH 4/7] add lora script --- VIT/train_lora.py | 363 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 363 insertions(+) create mode 100644 VIT/train_lora.py diff --git a/VIT/train_lora.py b/VIT/train_lora.py new file mode 100644 index 0000000..38dc3e3 --- /dev/null +++ b/VIT/train_lora.py @@ -0,0 +1,363 @@ +# coding=utf-8 +from __future__ import absolute_import, division, print_function + +import logging +import argparse +import os +import random +import numpy as np +import copy +from datetime import timedelta +import matplotlib.pyplot as plt +from tqdm import tqdm + +import torch +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter + +from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule +from utils.data_loader import get_loader +from utils.dist_util import get_world_size + +from transformers import AutoModelForImageClassification +from peft import LoraConfig, get_peft_model + +logger = logging.getLogger(__name__) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def save_model(args, model, global_step=None): + model_to_save = model.module if hasattr(model, 'module') else model + if global_step is not None: + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % str(global_step)) + else: + model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name) + model.save_pretrained(model_checkpoint) + logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir) + + +def setup(args): + # Prepare model + if args.checkpoint: + model = AutoModelForImageClassification.from_pretrained(args.pretrained_dir, num_labels=args.num_classes) + else: + model = AutoModelForImageClassification.from_pretrained(args.model_type, num_labels=args.num_classes, + ignore_mismatched_sizes=True) + + model.to(args.device) + num_params = count_parameters(model) + + # logger.info("{}".format(config)) + logger.info("Training parameters %s", args) + logger.info("Total Parameter: \t%2.1fM" % num_params) + print(num_params) + + return args, model + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params/1000000 + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def valid(args, model, writer, test_loader, 
global_step, val_txt = ""): + # Validation! + eval_losses = AverageMeter() + + logger.info("\n") + logger.info("***** Running Validation *****") + logger.info(" Num steps = %d", len(test_loader)) + logger.info(" Batch size = %d", args.eval_batch_size) + + model.eval() + all_preds, all_label = [], [] + epoch_iterator = tqdm(test_loader, + desc="Validating... (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + loss_fct = torch.nn.CrossEntropyLoss() + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + with torch.no_grad(): + logits = model(x).logits + eval_loss = loss_fct(logits, y) + eval_losses.update(eval_loss.item()) + + preds = torch.argmax(logits, dim=-1) + + if len(all_preds) == 0: + all_preds.append(preds.detach().cpu().numpy()) + all_label.append(y.detach().cpu().numpy()) + else: + all_preds[0] = np.append( + all_preds[0], preds.detach().cpu().numpy(), axis=0 + ) + all_label[0] = np.append( + all_label[0], y.detach().cpu().numpy(), axis=0 + ) + epoch_iterator.set_description("Validating %s" %val_txt+ "... (loss=%2.5f)" % eval_losses.val) + + all_preds, all_label = all_preds[0], all_label[0] + accuracy = simple_accuracy(all_preds, all_label) + + logger.info("\n") + logger.info("Validation Results: %s" % val_txt) + logger.info("Global Steps: %d" % global_step) + logger.info("Valid Loss: %2.5f" % eval_losses.avg) + logger.info("Valid Accuracy: %2.5f" % accuracy) + + writer.add_scalars('test/accuracy', {val_txt: accuracy}, global_step) + return accuracy + + +def train(args, model): + """ Train the model """ + if args.local_rank in [-1, 0]: + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.output_dir+"/logs", exist_ok=True) + + writer = SummaryWriter(log_dir=os.path.join(args.output_dir+"/logs", args.name)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + # Prepare dataset + train_loader, test_loader = get_loader(args, cmin = (args.task_n-1)*10, cmax = args.task_n*10, relabel = False) + old_loaders = [get_loader(args, cmin = i*10, cmax = (i+1)*10, relabel = False) for i in range(args.task_n)] + if args.use_lora: + config = LoraConfig( + r=args.lora_rank, + lora_alpha=args.local_rank*2, + # target_modules=["query", "value"], + target_modules=["query", "value", "classifier"], + lora_dropout=0.1, + bias="none", + # modules_to_save=["classifier"], + ) + lora_model = get_peft_model(model, config) + model.to(args.device) + num_params_lora = count_parameters(model) + logger.info("\n") + logger.info("Total Parameter (LoRA model): \t%2.1fM" % num_params_lora) + print(num_params_lora) + + # Prepare optimizer and scheduler + optimizer = torch.optim.SGD(model.parameters(), + lr=args.learning_rate, + momentum=0.9, + weight_decay=args.weight_decay) + t_total = args.num_steps + if args.decay_type == "cosine": + scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + else: + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + + model.zero_grad() + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + losses = AverageMeter() + global_step, best_acc = 0, 0 + + accuracies = [[] for _ in range(args.task_n)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + # Train! 
+ logger.info("\n") + logger.info("***** Running training *****") + logger.info(" Total optimization steps = %d", args.num_steps) + logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * ( + torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + while True: + model.train() + epoch_iterator = tqdm(train_loader, + desc="Training (X / X Steps) (loss=X.X)", + bar_format="{l_bar}{r_bar}", + dynamic_ncols=True, + disable=args.local_rank not in [-1, 0]) + + + for step, batch in enumerate(epoch_iterator): + batch = tuple(t.to(args.device) for t in batch) + x, y = batch + output = model(x).logits + loss_fct = torch.nn.CrossEntropyLoss() + CEloss = loss_fct(output, y) + loss = CEloss + + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + loss.backward() + + if (step + 1) % args.gradient_accumulation_steps == 0: + losses.update(loss.item()*args.gradient_accumulation_steps) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scheduler.step() + optimizer.step() + optimizer.zero_grad() + global_step += 1 + + epoch_iterator.set_description( + "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) + ) + if args.local_rank in [-1, 0]: + writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) + writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) + if global_step % args.eval_every == 0 and args.local_rank in [-1, 0]: + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + accuracy = valid(args, model, writer, old_loaders[i][1], global_step, val_txt = txt) + accuracies[i].append(accuracy) + + if best_acc < accuracy: + save_model(args, model) + best_acc = accuracy + model.train() + + if global_step % t_total == 0: + break + losses.reset() + if global_step % t_total == 0: + break + + if args.local_rank in [-1, 0]: + writer.close() + logger.info("Best Accuracy: \t%f" % best_acc) + logger.info("End Training!") + + + timesteps = [i for i in range(0, args.num_steps+1, args.eval_every)] + for i in range(args.task_n): + txt = "cifar_"+str(i*10)+":"+str((i+1)*10-1) + plt.plot(timesteps, accuracies[i], label=txt) + + plt.title("Accuracy plot") + plt.xlabel("Epochs") + plt.ylabel("Accuracy") + plt.legend() + plt.savefig(os.path.join(args.output_dir, "accuracy_plot.png")) + + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--name", required=True, + help="Name of this run. 
Used for monitoring.") + parser.add_argument("--num_classes", required=True, type=int, + help="Number of output classes") + parser.add_argument("--task_n", required=True, type=int, + help="Number of the task currently training") + parser.add_argument("--model_type", + default="ViT-B_16", + help="Which variant to use.") + parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz", + help="Where to search for pretrained ViT models.") + parser.add_argument("--output_dir", default="output", type=str, + help="The output directory where checkpoints will be written.") + parser.add_argument("--img_size", default=224, type=int, + help="Resolution size") + parser.add_argument("--train_batch_size", default=512, type=int, + help="Total batch size for training.") + parser.add_argument("--fip_batch_size", default=512, type=int, + help="Total batch size for fip training.") + parser.add_argument("--eval_batch_size", default=64, type=int, + help="Total batch size for eval.") + parser.add_argument("--eval_every", default=40, type=int, + help="Run prediction on validation set every so many steps." + "Will always run one evaluation at the end of training.") + + parser.add_argument("--learning_rate", default=3e-2, type=float, + help="The initial learning rate for SGD.") + parser.add_argument("--weight_decay", default=0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--num_steps", default=240, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", + help="How to decay the learning rate.") + parser.add_argument("--warmup_steps", default=120, type=int, + help="Step of training to perform learning rate warmup for.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + + parser.add_argument("--local_rank", type=int, default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument('--loss_scale', type=float, default=0, + help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" + "0 (default value): dynamic loss scaling.\n" + "Positive power of 2: static loss scaling value.\n") + parser.add_argument('--checkpoint', action='store_true', + help="Whether to use custom pre-trained model or imagenet21k pretrained") + parser.add_argument('--use_lora', action='store_true', + help="Whether to use lora training or normal training") + parser.add_argument("--lora_rank", default=4, type=int, + help="Rank of LoRA adapter.") + args = parser.parse_args() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl', + timeout=timedelta(minutes=60)) + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" % + (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16)) + + # Set seed + set_seed(args) + + # Model & Tokenizer Setup + args, model = setup(args) + + # Training + train(args, model) + + +if __name__ == "__main__": + main() From b5c8f6f62cc7825704fa1ae1a5262a6a98cf9974 Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Tue, 11 Jun 2024 01:17:59 +0300 Subject: [PATCH 5/7] minor changes --- VIT/train_lora.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/VIT/train_lora.py b/VIT/train_lora.py index 38dc3e3..52db294 100644 --- a/VIT/train_lora.py +++ b/VIT/train_lora.py @@ -311,17 +311,12 @@ def main(): help="Step of training to perform learning rate warmup for.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") - parser.add_argument('--loss_scale', type=float, default=0, - help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" - "0 (default value): dynamic loss scaling.\n" - "Positive power of 2: static loss scaling value.\n") parser.add_argument('--checkpoint', action='store_true', help="Whether to use custom pre-trained model or imagenet21k pretrained") parser.add_argument('--use_lora', action='store_true', @@ -346,8 +341,8 @@ def main(): logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) - logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" % - (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16)) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s" % + (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1))) # Set seed set_seed(args) From 01f417a53288b8d9720a48406b83de7740be81e4 Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Wed, 12 Jun 2024 03:07:58 +0300 Subject: [PATCH 6/7] update readme --- VIT/README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/VIT/README.md b/VIT/README.md index 293e578..2f252fa 100644 --- a/VIT/README.md +++ b/VIT/README.md @@ -9,7 +9,10 @@ This repository contains the codebase and sample notebooks to run the scripts. P ### Notebooks `FIP_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UICCGijTeugrhhRSL4PQ0GOBM_mHcC70?usp=sharing): -In this Google Colab notebook, we will be running the continual learning script to learn new tasks with the reduction of catastrophic forgetting of the previous learnt tasks. +In this Google Colab notebook, we will be running the continual learning script to learn new tasks with the reduction of catastrophic forgetting of the previous learnt tasks using FIP. + +`LoRA_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/drive/1IfQzwpYOh6Lr6fbljIXxpfxjhP6qQRkC?usp=sharing): +In this Google Colab notebook, we will be running the continual learning script to learn new tasks using LoRA adapter. --- From 0487ed0713cdbaf04142ed100cf722972cc4a12e Mon Sep 17 00:00:00 2001 From: baheytharwat Date: Wed, 12 Jun 2024 03:09:41 +0300 Subject: [PATCH 7/7] update readme --- VIT/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VIT/README.md b/VIT/README.md index 2f252fa..fab156a 100644 --- a/VIT/README.md +++ b/VIT/README.md @@ -11,7 +11,7 @@ This repository contains the codebase and sample notebooks to run the scripts. P `FIP_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UICCGijTeugrhhRSL4PQ0GOBM_mHcC70?usp=sharing): In this Google Colab notebook, we will be running the continual learning script to learn new tasks with the reduction of catastrophic forgetting of the previous learnt tasks using FIP. -`LoRA_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/drive/1IfQzwpYOh6Lr6fbljIXxpfxjhP6qQRkC?usp=sharing): +`LoRA_VIT.ipynb`: [![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IfQzwpYOh6Lr6fbljIXxpfxjhP6qQRkC?usp=sharing): In this Google Colab notebook, we will be running the continual learning script to learn new tasks using LoRA adapter. ---
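
The patch series above adds two continual-learning training paths for ViT on CIFAR-100. For orientation, the sketch below distills the extra loss term that `train_fip.py` applies when `--use_fip` is passed: alongside the cross-entropy loss on the current task, it penalises a Bhattacharyya-style distance between the softmax outputs of the updated model and a frozen copy of the pre-task model on a replayed batch of old-task images (the `loss2` term in the script). This is a minimal sketch, not the script itself: the wrapper function name and the epsilon handling are illustrative (the script derives its epsilon from the frozen model's outputs on the current-task batch), and `model(x_old).logits` assumes a Hugging Face image-classification model, as in the patch.

```python
import torch

def fip_output_distance(model, old_model, x_old, eps_floor=1e-10):
    """Sketch of the output-distance penalty on replayed old-task inputs
    (the `loss2` term that train_fip.py adds to the cross-entropy loss
    when --use_fip is set). Assumes both models return `.logits`."""
    f_softmax = torch.nn.Softmax(dim=1)
    p_new = f_softmax(model(x_old).logits)              # current model, updated every step
    with torch.no_grad():
        p_old = f_softmax(old_model(x_old).logits)      # frozen copy of the pre-task model
    eps = max(eps_floor, p_old.min().item() * 1e-3)     # keep sqrt/log away from zero (illustrative)
    bc = torch.sum(torch.sqrt(p_new * p_old + eps), dim=1)   # per-sample Bhattacharyya coefficient
    return torch.sum(-torch.log(bc))                    # small when outputs on old data still match
```

The returned scalar is simply added to the new-task cross-entropy loss, so gradient steps that would change the network's behaviour on previously learned classes are penalised while the new task is being fit.

Similarly, `train_lora.py` attaches LoRA adapters to the attention query/value projections and the classification head via `peft` before fine-tuning on the new task. The sketch below shows that setup under stated assumptions: the helper name and the choice of `lora_alpha = 2 * rank` (a common convention) are mine, while the script computes `lora_alpha` from its command-line arguments and reads the rank from `--lora_rank`.

```python
from transformers import AutoModelForImageClassification
from peft import LoraConfig, get_peft_model

def attach_lora(model_name="google/vit-base-patch16-224", num_classes=100, rank=4):
    """Load a pretrained ViT classifier and inject LoRA adapters into the
    attention query/value projections and the classification head."""
    model = AutoModelForImageClassification.from_pretrained(
        model_name, num_labels=num_classes, ignore_mismatched_sizes=True)
    config = LoraConfig(
        r=rank,                                           # rank of the low-rank update
        lora_alpha=2 * rank,                              # scaling factor (assumption: 2x rank)
        target_modules=["query", "value", "classifier"],  # same modules targeted in train_lora.py
        lora_dropout=0.1,
        bias="none",
    )
    return get_peft_model(model, config)                  # base weights frozen, adapters trainable
```

With the adapters attached, only the low-rank adapter weights remain trainable and the rest of the ViT backbone is frozen, which is why the script logs the trainable-parameter count a second time after wrapping the model.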