Create extra samples with surplus images #272

Open · wants to merge 6 commits into base: main
open_flamingo/train/data.py (211 changes: 102 additions & 109 deletions)
@@ -1,6 +1,7 @@
"""
Preprocess and load datasets for training.
"""
from typing import List, Tuple

import functools
import io
@@ -26,6 +27,11 @@
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000

+# Tokens
+IMG = "<image>"
+EOC = "<|endofchunk|>"
+
+
try:
    import horovod.torch as hvd
except ImportError:
@@ -59,9 +65,7 @@ def preprocess_laion_text(sample, tokenizer, max_tokens=32):
    Captions are truncated to 32 tokens by default.
    """
    tokenizer.padding_side = "right"
-    sample = [
-        (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
-    ]
+    sample = [(f"{IMG}{s.strip()}{EOC}{tokenizer.eos_token}") for s in sample]
    text = tokenizer(
        sample,
        max_length=max_tokens,
@@ -72,6 +76,61 @@
    return text["input_ids"], text["attention_mask"]


+def zero_pad_image_tensors(image_tensors, max_num_images: int):
+    zero_padding = torch.zeros(
+        (
+            max_num_images - len(image_tensors),
+            N_CHANNELS,
+            image_tensors[0].shape[1],
+            image_tensors[0].shape[2],
+        ),
+        dtype=torch.float,
+    )
+    padded_image_tensors = torch.cat((image_tensors, zero_padding), dim=0)
+    return padded_image_tensors
+
+
+def preprocess_text(text: str) -> str:
+    text = (
+        text.replace(EOC, "", 1)  # but remove first eoc
+        .replace(f" {EOC}", EOC)  # whitespace cleanup
+        .replace(f"{IMG} ", IMG)
+        .replace(f" {IMG}", IMG)
+    )
+    return f"{text}{EOC}"
+
+
+def tokenize_text(tokenizer, text: str, max_tokens: int):
+    text = f"{text}{tokenizer.eos_token}"
+    tokenizer.padding_side = "right"
+    return tokenizer(
+        text,
+        max_length=max_tokens,
+        truncation=True,
+        padding="max_length",
+        return_tensors="pt",
+    )
+
+
+def sample_validation(tokenizer, text_tensor, min_num_images):
+    img_tokens_idx = tokenizer.additional_special_tokens_ids[
+        tokenizer.additional_special_tokens.index(IMG)
+    ]
+    # reject sequences with too few images (after truncation)
+    num_images = torch.count_nonzero(text_tensor["input_ids"] == img_tokens_idx)
+    if num_images < min_num_images:
+        raise ValueError(f"Fewer than {min_num_images} images in sample")
+    # 50% chance of keeping single image samples
+    elif num_images == 1 and random.random() <= 0.5:
+        raise ValueError("Only one image in sample")
+
+    # avoid the situation where there's one <image> token and it's at the end
+    if num_images == 1 and text_tensor["input_ids"][:, -1] == img_tokens_idx:
+        raise ValueError(
+            "Only one image at the end of sample, so labels will all be -100"
+        )
+
+
def preprocess_gpt_interleaved(
    info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens=256
):
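Note: the four helpers added above are self-contained, so their behavior can be checked outside the training pipeline. A minimal sketch of how they compose, with the helpers above in scope and assuming a HuggingFace tokenizer with the "<image>" and "<|endofchunk|>" special tokens registered the way the OpenFlamingo training scripts do; the gpt2 checkpoint and the pad-token line are illustrative assumptions only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any HF tokenizer works
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

raw = "<|endofchunk|><image> a cat <|endofchunk|><image> a dog"
clean = preprocess_text(raw)
# -> "<image>a cat<|endofchunk|><image>a dog<|endofchunk|>"
# (first eoc dropped, spaces around special tokens removed, closing eoc appended)

text_tensor = tokenize_text(tokenizer, clean, max_tokens=32)
sample_validation(tokenizer, text_tensor, min_num_images=1)  # raises ValueError on rejects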
@@ -80,6 +139,7 @@ def preprocess_gpt_interleaved(
    """
    text = info["example"]
    text = re.sub(r"_!_IMAGE\d+_!_", "<|endofchunk|><image>", text)
+    text = preprocess_text(text)

    # convert images from base64 to PIL
    images = []
@@ -90,49 +150,26 @@

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)
-    keep_ixs = range(min(len(images_tensors), max_num_images))
-    images_tensors = images_tensors[keep_ixs]
-    if len(images_tensors) < max_num_images:
-        zero_padding = torch.zeros(
-            (max_num_images - len(images_tensors), 3, 224, 224), dtype=torch.float
-        )
-        images_tensors = torch.cat((images_tensors, zero_padding), dim=0)
-
-    # preprocess and tokenize text
-    text = text.replace("<|endofchunk|>", "", 1)  # but remove first eoc
-    # whitespace cleanup
-    text = (
-        text.replace(" <|endofchunk|>", "<|endofchunk|>")
-        .replace("<image> ", "<image>")
-        .replace(" <image>", "<image>")
-    )
-
-    indices = [m.start() for m in re.finditer("<image>", text)]
-    if len(indices) > max_num_images:
-        start_index = indices[max_num_images - 1]
-        text = text[:start_index]
+    for pos in range(0, len(images_tensors), max_num_images):
+        chunk_ixs = range(pos, pos + max_num_images)
+        chunk_image_tensors = images_tensors[chunk_ixs]
+        if len(chunk_image_tensors) < max_num_images:
+            chunk_image_tensors = zero_pad_image_tensors(
+                chunk_image_tensors, max_num_images
+            )

-    text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
-    tokenizer.padding_side = "right"
-    text_tensor = tokenizer(
-        text,
-        max_length=max_tokens,
-        truncation=True,
-        padding="max_length",
-        return_tensors="pt",
-    )
+        # get the start idx of the 1st image token and the end idx of the last eoc token of the chunk
+        img_tkn_start_idx = [m.start() for m in re.finditer(IMG, text)][:max_num_images]
+        eoc_tkn_end_idx = [m.end() for m in re.finditer(EOC, text)][:max_num_images]
+        text = text[img_tkn_start_idx[chunk_ixs[0]] : eoc_tkn_end_idx[chunk_ixs[-1]]]
+        text_tensor = tokenize_text(tokenizer, text, max_tokens)

-    # reject sequences with too few images after truncation
-    num_images = torch.count_nonzero(
-        text_tensor["input_ids"]
-        == tokenizer.additional_special_tokens_ids[
-            tokenizer.additional_special_tokens.index("<image>")
-        ]
-    )
-    if num_images < min_num_images:
-        raise ValueError(f"Fewer than {min_num_images} images in sample")
+        sample_validation(tokenizer, text_tensor, min_num_images)

-    return (images_tensors, (text_tensor["input_ids"], text_tensor["attention_mask"]))
+        yield (
+            chunk_image_tensors,
+            (text_tensor["input_ids"], text_tensor["attention_mask"]),
+        )
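Since preprocess_gpt_interleaved now yields one sample per chunk of up to max_num_images images instead of returning a single tuple, callers must iterate over it. A hedged sketch of the consuming side, with sample_info standing in for one decoded interleaved record (names here are illustrative, not from this diff):

for images, (input_ids, attention_mask) in preprocess_gpt_interleaved(
    sample_info, tokenizer, clip_processor, min_num_images=1, max_num_images=5
):
    # each iteration is one training sample built from one image chunk
    ...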


def preprocess_interleaved(
Expand Down Expand Up @@ -169,8 +206,7 @@ def preprocess_interleaved(
if len(rawbytes) // 1000 <= MIN_KB:
continue

image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
valid_images.append(image)
valid_images.append(Image.open(io.BytesIO(rawbytes)).convert("RGB"))
valid_image_indices.append(i)

if len(valid_image_indices) == 0:
Expand Down Expand Up @@ -198,74 +234,31 @@ def preprocess_interleaved(
raise ValueError("No images in sample")

# preprocess and pad images
# yield an example for each max_num_images valid images
images_tensors = preprocess_image(images, clip_processor)
keep_ixs = range(min(len(images_tensors), max_num_images))
images_tensors = images_tensors[keep_ixs]
sentence_ixs = [sentence_ixs[ix] for ix in keep_ixs]
if len(images_tensors) < max_num_images:
zero_padding = torch.zeros(
(
max_num_images - len(images_tensors),
N_CHANNELS,
images_tensors[0].shape[1],
images_tensors[0].shape[2],
),
dtype=torch.float,
)
images_tensors = torch.cat((images_tensors, zero_padding), dim=0)

# preprocess and tokenize text
# add in <image> and <eoc> tokens
for ix in sentence_ixs:
sentences[ix] = f"<|endofchunk|><image>{sentences[ix]}"
text = " ".join(sentences)
text = text.replace("<|endofchunk|>", "", 1) # but remove first eoc
# whitespace cleanup
text = (
text.replace(" <|endofchunk|>", "<|endofchunk|>")
.replace("<image> ", "<image>")
.replace(" <image>", "<image>")
)
text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
tokenizer.padding_side = "right"
text_tensor = tokenizer(
text,
max_length=max_tokens,
truncation=True,
padding="max_length",
return_tensors="pt",
)
for pos in range(0, len(images_tensors), max_num_images):
chunk_ixs = range(pos, pos + max_num_images)
chunk_image_tensors = images_tensors[chunk_ixs]
sentence_ixs = [sentence_ixs[ix] for ix in chunk_ixs]
if len(chunk_image_tensors) < max_num_images:
chunk_image_tensors = zero_pad_image_tensors(
chunk_image_tensors, max_num_images
)

# reject sequences with too few images (after truncation)
num_images = torch.count_nonzero(
text_tensor["input_ids"]
== tokenizer.additional_special_tokens_ids[
tokenizer.additional_special_tokens.index("<image>")
]
)
if num_images < min_num_images:
raise ValueError(f"Fewer than {min_num_images} images in sample")
elif (
num_images == 1 and random.random() <= 0.5
): # 50% chance of keeping single image samples
raise ValueError("Only one image in sample")
# preprocess and tokenize text
# add in <image> and <eoc> tokens
for ix in sentence_ixs:
sentences[ix] = f"{EOC}{IMG}{sentences[ix]}"
text = " ".join(sentences)
text = preprocess_text(text)
text_tensor = tokenize_text(tokenizer, text, max_tokens)

# avoid the situation where there's one <image> token and it's at the end
if (
num_images == 1
and text_tensor["input_ids"][:, -1]
== tokenizer.additional_special_tokens_ids[
tokenizer.additional_special_tokens.index("<image>")
]
):
raise ValueError(
"Only one image at the end of sample, so labels will all be -100"
)
sample_validation(tokenizer, text_tensor, min_num_images)

return (
images_tensors,
(text_tensor["input_ids"], text_tensor["attention_mask"]),
)
yield (
chunk_image_tensors,
(text_tensor["input_ids"], text_tensor["attention_mask"]),
)


def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
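One consequence of this change for the rest of the file: a webdataset map stage that calls a generator-returning preprocess function emits an iterator per record, so the stream has to be flattened before batching. A minimal sketch of one way to do that, assuming the wds pipeline built in get_mmc4_dataset; flatten_samples is a hypothetical helper, not part of this diff:

def flatten_samples(stream):
    # upstream items are generators of (images, text) samples;
    # re-emit the individual samples so batching sees a flat stream
    for sample_generator in stream:
        yield from sample_generator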