From 04755bc956667df19011d27b0a6bbb6812313cb3 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Wed, 21 Jun 2023 04:00:17 +0000
Subject: [PATCH 1/2] add back train_num_samples

---
 pipeline/train/instruction_following.py    | 1 +
 pipeline/train/instruction_following_ds.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/pipeline/train/instruction_following.py b/pipeline/train/instruction_following.py
index dc394202..02e02d44 100644
--- a/pipeline/train/instruction_following.py
+++ b/pipeline/train/instruction_following.py
@@ -276,6 +276,7 @@ def main():
     parser.add_argument("--weight_decay", default=0.1, type=float)
     # data args
     parser.add_argument("--workers", type=int, default=4)
+    parser.add_argument("--train_num_samples", type=int, default=None)
     # distributed training args
     parser.add_argument(
         "--dist-url",
diff --git a/pipeline/train/instruction_following_ds.py b/pipeline/train/instruction_following_ds.py
index e17654bb..ef435771 100644
--- a/pipeline/train/instruction_following_ds.py
+++ b/pipeline/train/instruction_following_ds.py
@@ -234,6 +234,7 @@ def main():
     parser.add_argument("--weight_decay", default=0.1, type=float)
     # data args
     parser.add_argument("--workers", type=int, default=4)
+    parser.add_argument("--train_num_samples", type=int, default=None)
     # distributed training args
     parser.add_argument(
         "--dist-url",

From 927be4bd9ee8f9b9ec5526547986c79766611d12 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Wed, 21 Jun 2023 04:06:54 +0000
Subject: [PATCH 2/2] add back to ddp file

---
 pipeline/train/instruction_following_ddp.py | 520 ++++++++++++++++++++
 1 file changed, 520 insertions(+)
 create mode 100644 pipeline/train/instruction_following_ddp.py

diff --git a/pipeline/train/instruction_following_ddp.py b/pipeline/train/instruction_following_ddp.py
new file mode 100644
index 00000000..a78d0f38
--- /dev/null
+++ b/pipeline/train/instruction_following_ddp.py
@@ -0,0 +1,520 @@
+""" Main training script """
+
+import argparse
+import glob
+import os
+import random
+
+import numpy as np
+import torch
+import torch.nn
+import wandb
+from pipeline.train.data import get_data
+from pipeline.train.distributed import init_distributed_device, world_info_from_env
+from transformers import (
+    get_constant_schedule_with_warmup,
+    get_cosine_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+    CLIPImageProcessor,
+)
+
+from pipeline.train.train_utils import (
+    AverageMeter,
+    get_autocast,
+    get_cast_dtype,
+    get_checkpoint,
+)
+
+from flamingo.modeling_flamingo import FlamingoForConditionalGeneration
+from flamingo.configuration_flamingo import FlamingoConfig
+from otter.modeling_otter import OtterForConditionalGeneration
+from otter.configuration_otter import OtterConfig
+from tqdm import tqdm
+import time
+
+from pipeline.multi_instruct_data_utils.arguments import add_data_args
+from accelerate import Accelerator
+
+import sys
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
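+
+# Optional sanity check (illustrative addition): TF32 only takes effect on
+# Ampere-class GPUs (compute capability >= 8.0) or newer.
+if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] < 8:
+    print("TF32 flags are set, but this GPU predates Ampere; they have no effect.")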
+
+
+def random_seed(seed=42, rank=0):
+    torch.manual_seed(seed + rank)
+    np.random.seed(seed + rank)
+    random.seed(seed + rank)
+
+
+def train_one_epoch(
+    args,
+    model,
+    epoch,
+    multi_instruct_loader,
+    tokenizer,
+    optimizer,
+    lr_scheduler,
+    device_id,
+    accelerator,
+    wandb,
+):
+    # num_batches_per_epoch = multi_instruct_loader.num_batches
+    num_batches_per_epoch = len(multi_instruct_loader)
+    total_training_steps = num_batches_per_epoch * args.num_epochs
+
+    # autocast = get_autocast(args.precision)
+    # cast_dtype = get_cast_dtype(args.precision)
+
+    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
+    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
+    answer_token_id = tokenizer("<answer>", add_special_tokens=False)["input_ids"][-1]
+
+    model.train()
+
+    # setup logging
+    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
+    data_time_m = AverageMeter()  # avg time to load one batch of multi_instruct (= 1 batch regardless of gradient accum)
+    end = time.time()
+
+    # loop through dataloader
+    for num_steps, batch_multi_instruct in tqdm(
+        enumerate(multi_instruct_loader),
+        disable=args.rank != 0,
+        total=total_training_steps,
+        initial=(epoch * num_batches_per_epoch),
+    ):
+        data_time_m.update(time.time() - end)
+
+        global_step = num_steps + epoch * num_batches_per_epoch
+        #### MULTI_INSTRUCT FORWARD PASS ####
+
+        # images = (
+        #     batch_multi_instruct["net_input"]["patch_images"]
+        #     .to(device_id, dtype=cast_dtype, non_blocking=True)
+        #     .unsqueeze(1)
+        #     .unsqueeze(1)
+        # )
+        # input_ids = batch_multi_instruct["net_input"]["input_ids"].to(
+        #     device_id, dtype=cast_dtype, non_blocking=True
+        # )
+        # attention_mask = batch_multi_instruct["net_input"]["attention_masks"].to(
+        #     device_id, dtype=cast_dtype, non_blocking=True
+        # )
+
+        images = batch_multi_instruct["net_input"]["patch_images"].unsqueeze(1).unsqueeze(1)
+        input_ids = batch_multi_instruct["net_input"]["input_ids"]
+        attention_mask = batch_multi_instruct["net_input"]["attention_masks"]
+
+        labels = input_ids.clone()
+        labels[labels == tokenizer.pad_token_id] = -100
+        labels[:, 0] = -100
+
+        for i in range(labels.shape[0]):
+            # remove loss for any token before the <answer> token
+            label_idx = 0
+            while label_idx < labels.shape[1] and labels[i][label_idx] != answer_token_id:
+                labels[i][label_idx] = -100
+                label_idx += 1
+
+        labels[labels == answer_token_id] = -100
+        labels[labels == media_token_id] = -100
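+
+        # Worked example of the masking above (illustrative): for the sequence
+        #     <image> User: what is in the photo? GPT: <answer> A cat. <|endofchunk|>
+        # every position up to and including <answer> is set to -100 and ignored by
+        # the loss, as are the <image> and padding positions, so the model is
+        # supervised only on "A cat. <|endofchunk|>".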
+
+        # labels.to(device_id, dtype=cast_dtype, non_blocking=True)
+
+        with accelerator.accumulate(model):
+            # with autocast():
+            # with accelerator.autocast():
+            loss_multi_instruct = model(
+                vision_x=images,
+                lang_x=input_ids,
+                attention_mask=attention_mask,
+                labels=labels,
+            )[0]
+
+            # divided_loss_multi_instruct = loss_multi_instruct
+
+            #### BACKWARD PASS ####
+            accelerator.backward(loss_multi_instruct)
+
+            #### MASK GRADIENTS FOR EMBEDDINGS ####
+            # Note (anas): Do not apply weight decay to embeddings as it will break this function.
+            # Only the newly added <answer> token row should be updated; mask_embedding
+            # zeroes the gradients of every pre-trained embedding row.
+            # def mask_embedding(m):
+            #     if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
+            #         zero_mask = torch.zeros_like(m.weight.grad)
+            #         # zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
+            #         # zero_mask[endofchunk_token_id] = torch.ones_like(
+            #         #     zero_mask[endofchunk_token_id]
+            #         # )
+            #         zero_mask[answer_token_id] = torch.ones_like(zero_mask[answer_token_id])
+            #         m.weight.grad = m.weight.grad * zero_mask
+            def mask_embedding(m):
+                if m.weight.requires_grad:
+                    zero_mask = torch.zeros_like(m.weight.grad)
+                    zero_mask[answer_token_id] = torch.ones_like(zero_mask[answer_token_id])
+                    m.weight.grad = m.weight.grad * zero_mask
+
+            if args.mask_lm_head:
+                # model.apply(mask_embedding)
+                model.module.lang_encoder.model.embed_tokens.apply(mask_embedding)
+                model.module.lang_encoder.lm_head.apply(mask_embedding)
+
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            if accelerator.sync_gradients:
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+
+            # step optimizer and log
+            # if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
+            #     num_steps == num_batches_per_epoch - 1
+            # ):
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+        # step time and reset end outside of rank 0
+        step_time_m.update(time.time() - end)
+        end = time.time()
+
+        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
+            if args.rank == 0 and args.report_to_wandb:
+                # compute within rank 0
+                multi_instruct_samples_per_second = args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val
+                multi_instruct_samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val
+
+                wandb.log(
+                    {
+                        "data_time": data_time_m.avg,
+                        "step_time": step_time_m.avg,
+                        "multi_instruct_samples_per_second": multi_instruct_samples_per_second,
+                        "multi_instruct_samples_per_second_per_gpu": multi_instruct_samples_per_second_per_gpu,
+                        "lr": optimizer.param_groups[0]["lr"],
+                    },
+                    commit=False,
+                )
+                step_time_m.reset()
+                data_time_m.reset()
+
+                wandb.log(
+                    {
+                        "loss_multi_instruct": loss_multi_instruct.item(),
+                        "global_step": global_step // args.gradient_accumulation_steps,
+                    },
+                    commit=True,
+                )
+
+        # Log loss to console
+        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
+            print(
+                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss Multi-Instruct: {loss_multi_instruct.item():.3f}"
+            )
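+
+
+# Worked example for the throughput numbers logged above (illustrative): with
+# gradient_accumulation_steps=4, batch_size=8, world_size=2, and a measured step
+# time of 1.0 s, multi_instruct_samples_per_second is 4 * 8 * 2 / 1.0 = 64.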
Loss Multi-Instruct: {loss_multi_instruct.item():.3f}" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) + parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) + parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str) + parser.add_argument( + "--tokenizer_path", + default="facebook/opt-30b", + type=str, + help="path to tokenizer", + ) + parser.add_argument( + "--cross_attn_every_n_layers", + type=int, + default=1, + help="how often to add a cross-attention layer after each transformer layer", + ) + parser.add_argument( + "--external_save_dir", + type=str, + default=None, + help="set to save model to external path", + ) + parser.add_argument( + "--run_name", + type=str, + default="otter_9b", + help="used to name saving directory and wandb run", + ) + parser.add_argument("--use_media_placement_augmentation", action="store_true") + parser.add_argument("--offline", action="store_true") + parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps") + # Sum of gradient optimization batch size + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + help="path to huggingface model or model identifier from local path or huggingface.co", + default=None, + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", + default=None, + ) + parser.add_argument( + "--overwrite_checkpoint", + action="store_true", + ) + parser.add_argument( + "--delete_previous_checkpoint", + action="store_true", + help="delete previous checkpoint when saving new checkpoint", + ) + parser.add_argument( + "--multi_instruct_path", + type=str, + help="path to multi_instruct dataset, this should be a glob pattern such as vision_language_examples.tsv", + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--learning_rate", default=1e-4, type=float) + parser.add_argument( + "--lr_scheduler", + default="constant", + type=str, + help="constant, linear, or cosine", + ) + parser.add_argument("--loss_multiplier_multi_instruct", type=float, default=1.0) + parser.add_argument("--warmup_steps", default=1000, type=int) + parser.add_argument("--warmup_steps_ratio", default=None, type=float) + parser.add_argument("--weight_decay", default=0.1, type=float) + parser.add_argument( + "--precision", + choices=["amp_bf16", "amp_bfloat16", "bf16", "amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + # data args + parser.add_argument("--workers", type=int, default=4) + parser.add_argument("--train_num_samples", type=int, default=None) + parser.add_argument("--dataset_resampled", action="store_true") + # distributed training args + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when 
+
+    if args.save_checkpoints_to_wandb and not args.report_to_wandb:
+        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
+
+    if args.offline:
+        os.environ["WANDB_MODE"] = "offline"
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+
+    args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+    if args.world_size > 1:
+        device_id = init_distributed_device(args)
+    else:
+        device_id = 0
+
+    random_seed(args.seed)
+
+    if args.pretrained_model_name_or_path is not None:
+        model = FlamingoForConditionalGeneration.from_pretrained(
+            args.pretrained_model_name_or_path,
+            device_map={"": torch.cuda.current_device()},
+            local_files_only=args.offline,
+        )
+    else:
+        config = FlamingoConfig.from_json_file("./flamingo/config.json")
+        model = FlamingoForConditionalGeneration(config=config)
+
+    tokenizer = model.text_tokenizer
+    image_processor = CLIPImageProcessor()
+
+    # add <image> and <answer> tokens to tokenizer
+    tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
+
+    model.lang_encoder.resize_token_embeddings(len(tokenizer))
+
+    args.tokenizer = tokenizer
+
+    random_seed(args.seed, args.rank)
+
+    print(f"Start running training on rank {args.rank}.")
+
+    device_id = args.rank % torch.cuda.device_count()
+
+    multi_instruct_loader = get_data(args, tokenizer, "multi_instruct")
+
+    def get_grouped_params(model):
+        params_with_wd, params_without_wd = [], []
+
+        def apply_decay(x):
+            return "gated_cross_attn_layer" in x and "ff_gate" not in x and "attn_gate" not in x and "norm" not in x and "bias" not in x
+
+        for n, p in model.named_parameters():
+            # if p.requires_grad:
+            if apply_decay(n):
+                params_with_wd.append(p)
+            else:
+                params_without_wd.append(p)
+
+        return [
+            {"params": params_with_wd, "weight_decay": args.weight_decay},
+            {"params": params_without_wd, "weight_decay": 0.0},
+        ]
+
+    # args.train_num_samples = (
+    #     multi_instruct_dataset.dataloader.num_samples
+    #     if args.train_num_samples is None
+    #     else args.train_num_samples
+    # )
+    args.train_num_samples = multi_instruct_loader.num_samples if args.train_num_samples is None else args.train_num_samples
+
+    # total_training_steps = (
+    #     (args.train_num_samples) // (args.batch_size * args.world_size)
+    # ) * args.num_epochs
+    total_training_steps = len(multi_instruct_loader) * args.num_epochs
+
+    resume_from_epoch = 0
+    checkpoint = None
+    # check if a checkpoint exists for this run
+    args.external_save_dir = os.path.join(args.external_save_dir, args.run_name) if args.external_save_dir else args.run_name
+    if os.path.exists(f"{args.external_save_dir}") and args.resume_from_checkpoint is None and args.pretrained_model_name_or_path is None:
+        checkpoint_list = glob.glob(f"{args.external_save_dir}/checkpoint_*.pt")
+        if len(checkpoint_list) == 0:
+            print(f"Found no checkpoints for run {args.external_save_dir}.")
+        else:
+            args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
+            print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.external_save_dir}.")
+
+        if args.resume_from_checkpoint is not None:
+            if args.rank == 0:
+                print(f"Loading checkpoint from {args.resume_from_checkpoint}")
+            checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
+            model.load_state_dict(checkpoint["model_state_dict"], False)
+            resume_from_epoch = checkpoint["epoch"] + 1
+    elif args.resume_from_checkpoint is not None:
+        print(f"Loading checkpoint from {args.resume_from_checkpoint}")
+        model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location="cpu"), False)
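+
+    # get_grouped_params above applies weight decay only to the gated
+    # cross-attention weights; ff/attn gates, norms, and biases stay decay-free.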
{args.external_save_dir}.") + else: + args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1] + print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.external_save_dir}.") + + if args.rank == 0: + print(f"Loading checkpoint from {args.resume_from_checkpoint}") + checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model_state_dict"], False) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) + resume_from_epoch = checkpoint["epoch"] + 1 + + elif args.resume_from_checkpoint is not None: + print(f"Loading checkpoint from {args.resume_from_checkpoint}") + model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location="cpu"), False) + + optimizer = torch.optim.AdamW(get_grouped_params(model), lr=args.learning_rate) + # model.gradient_checkpointing_enable() + + if args.rank == 0: + print(f"Total training steps: {total_training_steps}") + + args.warmup_steps = total_training_steps * args.warmup_steps_ratio if args.warmup_steps_ratio is not None else args.warmup_steps + + if args.lr_scheduler == "linear": + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps, + num_training_steps=total_training_steps // args.gradient_accumulation_steps, + ) + elif args.lr_scheduler == "cosine": + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps // args.gradient_accumulation_steps, + num_training_steps=total_training_steps // args.gradient_accumulation_steps, + ) + else: + lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps) + + if args.rank == 0 and args.report_to_wandb: + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.run_name, + config=vars(args), + ) + + # import pdb;pdb.set_trace() + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # multi_instruct_loader = multi_instruct_dataset.dataloader + model, optimizer, lr_scheduler, multi_instruct_loader = accelerator.prepare(model, optimizer, lr_scheduler, multi_instruct_loader) + model.train() + + # From yh + device_id = accelerator.device + + for epoch in range(resume_from_epoch, args.num_epochs): + # multi_instruct_dataset.set_epoch(epoch) + + train_one_epoch( + args=args, + model=model, + epoch=epoch, + tokenizer=tokenizer, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + multi_instruct_loader=multi_instruct_loader, + accelerator=accelerator, + device_id=device_id, + wandb=wandb, + ) + if args.rank == 0: + if not os.path.exists(args.external_save_dir): + os.makedirs(args.external_save_dir) + + accelerator.wait_for_everyone() + + accelerator.wait_for_everyone() + if args.rank == 0: + if not os.path.exists(args.external_save_dir): + os.makedirs(args.external_save_dir) + + unwrapped_model = accelerator.unwrap_model(model) + accelerator.save( + get_checkpoint(model=unwrapped_model), + f"{args.external_save_dir}/final_weights.pt", + ) + if args.report_to_wandb and args.save_checkpoints_to_wandb: + wandb.save(f"{args.external_save_dir}/final_weights.pt") + if args.save_hf_model: + model.save_pretrained(f"{args.external_save_dir}/hf_model") + + +if __name__ == "__main__": + main()