Unofficial Training Code sample. #165
I tried to write the generation fine-tuning code, but since I'm not very familiar with distributed training, I haven't finished debugging it yet.
Let me share the method I used (it's a bit legacy, but here it is):

```python
    def gen_preprocess(self, images):
        """Encode each image with the VQ model and collect its codebook indices."""
        gen_codebooks = []
        for image in images:
            image_tensor = self.gen_resize(image).unsqueeze(0)
            # info[2] holds the codebook indices returned by the VQ encoder
            quant, emb_loss, info = self.gen_vision_model.encode(image_tensor)
            gen_codebooks.append(info[2])
        return torch.stack(gen_codebooks, dim=0)
```
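One caveat with the stacking above: it only works if every image yields the same number of codebook indices. A minimal sketch of that assumption, with random indices standing in for the VQ encoder's `info[2]` output (the vocabulary size 16384 and token count 576 are illustrative, not taken from the actual model):

```python
import torch

# Stand-in for per-image codebook indices from a VQ encoder.
# All images must produce the same number of tokens (here 576),
# otherwise torch.stack below raises a size-mismatch error.
codebooks = [torch.randint(0, 16384, (576,)) for _ in range(4)]
batch = torch.stack(codebooks, dim=0)
print(batch.shape)  # torch.Size([4, 576])
```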
```python
    def process_train(
        self,
        question: str = None,
        answer: str = None,
        images: List[Image] = None,
        gen_images: List[Image] = None,
        **kwargs,
    ):
        """
        Args:
            question (str): the formatted question/prompt;
            answer (str): the formatted answer;
            images (List[ImageType]): images for the understanding encoder;
            gen_images (List[ImageType]): images encoded into VQ targets for generation;
            **kwargs:

        Returns:
            outputs (VLChatProcessorTrainOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - pixel_values (torch.FloatTensor): [n_images, 3, H, W]
                - num_image_tokens (List[int]): the number of image tokens
                - num_image_gen_tokens (List[int]): the number of generation tokens
                - target_gen_ids / gen_codebooks: VQ codebook targets
        """
        # if self.image_gen_tag in answer:
        #     answer = answer.replace(self.image_gen_tag, self.image_gen_tag * 576)
        sft_format = question + answer

        # tokenize
        input_ids = torch.LongTensor(self.tokenizer.encode(sft_format))

        # expand understanding image tokens in the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # expand generation image tokens in the input_ids
        gen_token_mask: torch.BoolTensor = input_ids == self.image_gen_id
        gen_indices = gen_token_mask.nonzero()
        input_ids, num_image_gen_tokens = self.add_image_gen_token(
            image_indices=gen_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")
        images_gen_outputs = self.gen_preprocess(gen_images)

        # re-tokenize the question alone to know how many leading tokens to mask
        question_input_ids = torch.LongTensor(self.tokenizer.encode(question))
        question_image_token_mask: torch.BoolTensor = question_input_ids == self.image_id
        question_image_indices = question_image_token_mask.nonzero()
        question_input_ids, _ = self.add_image_token(
            image_indices=question_image_indices,
            input_ids=question_input_ids,
        )

        # mask the question (plus the <image_start_tag>) out of the text targets
        target_input_ids = input_ids.clone()
        target_input_ids[: len(question_input_ids) + 1] = self.ignore_id

        # generation targets: VQ codebook indices at the gen-token positions,
        # ignore_id everywhere else (legacy code, single gen image assumed)
        target_gen_input_ids = torch.full((len(input_ids),), self.ignore_id)
        assert torch.sum(input_ids == self.image_gen_id) == len(images_gen_outputs[0])
        target_gen_input_ids[input_ids == self.image_gen_id] = images_gen_outputs[0]
        target_input_ids[input_ids == self.image_gen_id] = self.ignore_id

        return VLChatProcessorTrainOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
            num_image_gen_tokens=num_image_gen_tokens,
            target_ids=target_input_ids,
            target_gen_ids=target_gen_input_ids,
            gen_codebooks=images_gen_outputs,
        )
```

I'm simply checking whether the model can transform an image that goes into the understanding encoder (SigLIP) into the space of the generation encoder (VQ model), i.e. an image-to-image task.
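The masking above relies on positions set to `ignore_id` being skipped by the loss. A self-contained sketch of that mechanism (assuming `ignore_id` is -100, PyTorch's default `ignore_index`; the tensor sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

IGNORE_ID = -100  # must match the loss function's ignore_index

logits = torch.randn(6, 10)  # [seq_len, vocab_size]
targets = torch.tensor([3, 5, IGNORE_ID, IGNORE_ID, 1, 7])

# Positions equal to IGNORE_ID contribute nothing to the loss,
# which is how the masked question/image positions are skipped.
loss = F.cross_entropy(logits, targets, ignore_index=IGNORE_ID)

# Equivalent manual computation over the unmasked positions only:
keep = targets != IGNORE_ID
manual = F.cross_entropy(logits[keep], targets[keep])
assert torch.allclose(loss, manual)
```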
I have also implemented some code for understanding fine-tuning, using the sample in inference.py as a reference. Feedback and suggestions are welcome!