From 93a0d5e2d6f114e70d6e327b975bce4671cea9bb Mon Sep 17 00:00:00 2001
From: Li Bo
Date: Thu, 17 Aug 2023 19:16:52 +0800
Subject: [PATCH] add import exception for instruction_following models (#250)

---
 pipeline/train/instruction_following.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/pipeline/train/instruction_following.py b/pipeline/train/instruction_following.py
index 84cadc88..2c66a3f9 100644
--- a/pipeline/train/instruction_following.py
+++ b/pipeline/train/instruction_following.py
@@ -43,6 +43,13 @@
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Try importing IdeficsForVisionText2Text, and if it's not available, define a dummy class
+try:
+    from transformers import IdeficsForVisionText2Text
+except ImportError:
+    print("IdeficsForVisionText2Text does not exist")
+    IdeficsForVisionText2Text = type(None)
+
 
 def random_seed(seed=42, rank=0):
     torch.manual_seed(seed + rank)
@@ -122,17 +129,10 @@ def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, l
             labels[labels == answer_token_id] = -100
             labels[labels == media_token_id] = -100
 
-            # Try importing IdeficsForVisionText2Text, and if it's not available, define a dummy class
-            try:
-                from transformers import IdeficsForVisionText2Text
-            except ImportError:
-                print("IdeficsForVisionText2Text does not exist")
-                IdeficsForVisionText2Text = type(None)
-
             with accelerator.autocast():
                 unwrapped_model = accelerator.unwrap_model(model)
 
-                if isinstance(unwrapped_model, IdeficsForVisionText2Text):
+                if IdeficsForVisionText2Text is not None and isinstance(unwrapped_model, IdeficsForVisionText2Text):
                     # only for image model
                     max_num_images = images.shape[1]
                     image_attention_mask = get_image_attention_mask(input_ids, max_num_images, tokenizer)
@@ -701,7 +701,7 @@ def apply_decay(x):
             num_training_steps=total_training_steps // args.gradient_accumulation_steps,
         )
     else:
-        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps)
+        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
 
     if args.rank == 0 and args.report_to_wandb:
         wandb.init(
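
Note: the snippet below is a minimal, standalone sketch (not taken from the Otter codebase; describe_model is a hypothetical helper) of the pattern this patch moves to module scope: guard the optional import once, bind a dummy type when it fails, and keep the isinstance() check in the training loop safe.

    # Guarded optional import, as in the hunk at line 43 of the patched file.
    try:
        from transformers import IdeficsForVisionText2Text
    except ImportError:
        # Older transformers releases do not ship this class. type(None) is a
        # real type, so isinstance(x, IdeficsForVisionText2Text) stays legal;
        # binding plain None here would make isinstance() raise TypeError.
        print("IdeficsForVisionText2Text does not exist")
        IdeficsForVisionText2Text = type(None)

    def describe_model(model):
        # Mirrors the guarded branch in train_one_epoch: the Idefics-specific
        # path runs only when the class is importable and the model matches it.
        if IdeficsForVisionText2Text is not None and isinstance(model, IdeficsForVisionText2Text):
            return "idefics image-text path"
        return "generic path"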