
Commit 93a0d5e
add import exception for instruction_following models (#250)
Luodian authored Aug 17, 2023
1 parent 42d65de commit 93a0d5e
Showing 1 changed file with 9 additions and 9 deletions.
pipeline/train/instruction_following.py: 9 additions & 9 deletions
@@ -43,6 +43,13 @@
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Try importing IdeficsForVisionText2Text, and if it's not available, define a dummy class
+try:
+    from transformers import IdeficsForVisionText2Text
+except ImportError:
+    print("IdeficsForVisionText2Text does not exist")
+    IdeficsForVisionText2Text = type(None)
+
 
 def random_seed(seed=42, rank=0):
     torch.manual_seed(seed + rank)
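
For context on the fallback: binding the name to type(None) (i.e. NoneType) keeps later isinstance checks valid, since no real model object is an instance of NoneType. A minimal standalone sketch of this optional-import pattern, assuming only that transformers is installed (the is_idefics helper is illustrative, not from the repository):

    # Optional-import pattern: bind the missing class to NoneType so that
    # isinstance() stays legal and simply returns False for every model.
    try:
        from transformers import IdeficsForVisionText2Text  # absent in older transformers releases
    except ImportError:
        IdeficsForVisionText2Text = type(None)

    def is_idefics(model) -> bool:
        # False whenever the import failed, because no model has type NoneType.
        return isinstance(model, IdeficsForVisionText2Text)

One design note: because type(None) is itself a class rather than None, the extra "is not None" guard added in the hunk below is always true and therefore redundant, though harmless; the isinstance check alone already handles the missing-class case.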
@@ -122,17 +129,10 @@ def train_one_epoch(args, model, epoch, mimicit_loaders, tokenizer, optimizer, l
         labels[labels == answer_token_id] = -100
         labels[labels == media_token_id] = -100
 
-        # Try importing IdeficsForVisionText2Text, and if it's not available, define a dummy class
-        try:
-            from transformers import IdeficsForVisionText2Text
-        except ImportError:
-            print("IdeficsForVisionText2Text does not exist")
-            IdeficsForVisionText2Text = type(None)
-
         with accelerator.autocast():
             unwrapped_model = accelerator.unwrap_model(model)
 
-            if isinstance(unwrapped_model, IdeficsForVisionText2Text):
+            if IdeficsForVisionText2Text is not None and isinstance(unwrapped_model, IdeficsForVisionText2Text):
                 # only for image model
                 max_num_images = images.shape[1]
                 image_attention_mask = get_image_attention_mask(input_ids, max_num_images, tokenizer)
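
An aside on the unchanged label lines at the top of this hunk: writing -100 into labels works because -100 is PyTorch's default ignore_index for cross-entropy, so the masked answer/media token positions contribute nothing to the loss. A small self-contained illustration (shapes and values are made up):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 4, 10)               # (batch, seq_len, vocab_size)
    labels = torch.tensor([[3, -100, 5, -100]])  # -100 marks positions to skip

    # cross_entropy ignores targets equal to ignore_index (default -100),
    # so the masked positions receive no loss and no gradient.
    loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1))
    print(loss)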
@@ -701,7 +701,7 @@ def apply_decay(x):
             num_training_steps=total_training_steps // args.gradient_accumulation_steps,
         )
     else:
-        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps)
+        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
 
     if args.rank == 0 and args.report_to_wandb:
         wandb.init(
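
On this last hunk: the constant-schedule branch now takes its warmup length directly from args.warmup_steps, matching the cosine branch's use of the parsed arguments, instead of a separately maintained num_warmup_steps variable. A minimal sketch of the branch structure, assuming transformers' scheduler helpers and an args object shaped like the script's (the SimpleNamespace values and the "cosine"/"constant" flag names are illustrative):

    import torch
    from types import SimpleNamespace
    from transformers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

    args = SimpleNamespace(lr_scheduler="constant", warmup_steps=100, gradient_accumulation_steps=1)
    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    total_training_steps = 10_000

    if args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    else:
        # the fixed line: warmup length comes from the CLI flag
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)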