Skip to content

Commit

Permalink
Merge pull request #283 from mlfoundations/media_token_fix
Browse files Browse the repository at this point in the history
train_utils media token fix
  • Loading branch information
anas-awadalla authored Dec 2, 2023
2 parents c5feb97 + fa6af69 commit eb6b8aa
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion open_flamingo/train/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from data_utils import DataInfo
import random
import numpy as np
import torch.nn as nn


def train_one_epoch(
Expand Down Expand Up @@ -78,7 +79,7 @@ def train_one_epoch(
f"{datasets[dataset_ix].name}_num_tokens"
] = attention_mask.sum().item()
batch_metadata_to_log[f"{datasets[dataset_ix].name}_num_images"] = (
(input_ids == model.media_token_id).sum().item()
(input_ids == unwrap_model(model).media_token_id).sum().item()
)

# forward pass
Expand Down Expand Up @@ -188,6 +189,16 @@ def random_seed(seed=42, rank=0):
random.seed(seed + rank)


def unwrap_model(model):
    """
    Return the module held by a DataParallel / DistributedDataParallel
    wrapper, or ``model`` itself when it is not one of those wrappers.

    Only a single level of wrapping is removed: the wrapper's ``.module``
    attribute is returned as-is.
    """
    # The parallel wrappers expose the wrapped network as `.module`;
    # attributes defined on the wrapped model are not reachable through
    # the wrapper itself, hence the need to unwrap before reading them.
    parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, parallel_wrappers) else model


################################
# Helper functions for logging #
################################
Expand Down

0 comments on commit eb6b8aa

Please sign in to comment.