Unofficial Training Code sample. #165

Open
Meaw0415 opened this issue Feb 12, 2025 · 2 comments

Comments


Meaw0415 commented Feb 12, 2025

I have implemented some code for understanding fine-tuning, using the sample in inference.py as a reference. Feedback and suggestions are welcome!

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images


accelerator = Accelerator(mixed_precision="bf16")  
device = accelerator.device


model_path = "deepseek-ai/Janus-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
).to(device)

vl_gpt.train()
# freeze the generation embedding table; understanding fine-tuning does not update it
for name, param in vl_gpt.named_parameters():
    if "gen_embed" in name:
        print(f"Parameter: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
        param.requires_grad = False
        # check that the parameter is now frozen
        print(f"Parameter: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")


lr = 1e-4 
optimizer = optim.AdamW(vl_gpt.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
criterion = nn.CrossEntropyLoss(ignore_index=-100)  
gradient_clip = 1.0


vl_gpt, optimizer = accelerator.prepare(vl_gpt, optimizer)


def train_step(model, optimizer, criterion):
    model.train()
    optimizer.zero_grad()


    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\nConvert the formula into latex code.",
            "images": ["images/equation.png"],
        },
        {"role": "Assistant", "content": ""},
    ]


    pil_images = load_pil_images(conversation)


    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(device) 


    # unwrap the DDP/Accelerate wrapper and cast to bf16
    # (ideally the cast would be done once, outside the training step)
    model = model.module if hasattr(model, "module") else model
    model = model.to(torch.bfloat16)

    inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs)


    with accelerator.autocast():  
        outputs = model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
        )
        logits = outputs.logits  # (batch_size, seq_len, vocab_size)


    # build labels: ignore padding, and shift so position t predicts token t+1
    labels = prepare_inputs.input_ids.clone().detach()
    labels[labels == tokenizer.pad_token_id] = -100
    loss = criterion(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
    )

    accelerator.backward(loss)  
    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)  
    optimizer.step()
    optimizer.zero_grad()

    return loss.item()

loss = train_step(vl_gpt, optimizer, criterion)
print(f"Training loss: {loss}")

for i in range(10):
    loss = train_step(vl_gpt, optimizer, criterion)
    print(f"Training loss: {loss}")
Meaw0415 (Author) commented

I tried to write the generation fine-tuning code, but since I'm not very familiar with distributed training, I haven't finished debugging it yet.
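
For the distributed part, the usual Accelerate skeleton looks roughly like this (a sketch, also not debugged; train_dataloader and compute_loss are placeholders for the batching and the forward/cross-entropy logic above), launched with accelerate launch train.py:

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=4)
vl_gpt, optimizer, train_dataloader = accelerator.prepare(vl_gpt, optimizer, train_dataloader)

vl_gpt.train()
for batch in train_dataloader:
    with accelerator.accumulate(vl_gpt):
        loss = compute_loss(vl_gpt, batch)   # forward + masked cross-entropy, as above
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(vl_gpt.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
    if accelerator.is_main_process:
        print(f"loss: {loss.item():.4f}")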


top-yun commented Feb 17, 2025

Let me share the method I used—it's a bit legacy, but here it is:

from typing import List

import torch
from PIL.Image import Image


def gen_preprocess(self, images):
    # encode each image with the VQ model and keep the codebook indices (info[2])
    gen_codebooks = []
    for image in images:
        image_tensor = self.gen_resize(image).unsqueeze(0)
        quant, emb_loss, info = self.gen_vision_model.encode(image_tensor)
        gen_codebooks.append(info[2])
    return torch.stack(gen_codebooks, dim=0)

def process_train(
    self,
    question: str = None,
    answer: str = None,
    images: List[Image] = None,
    gen_images: List[Image] = None,
    **kwargs,
):
    """

    Args:
        prompt (str): the formatted prompt;
        conversations (List[Dict]): conversations with a list of messages;
        images (List[ImageType]): the list of images;
        **kwargs:

    Returns:
        outputs (BaseProcessorOutput): the output of the processor,
            - input_ids (torch.LongTensor): [N + image tokens]
            - target_ids (torch.LongTensor): [N + image tokens]
            - images (torch.FloatTensor): [n_images, 3, H, W]
            - image_id (int): the id of the image token
            - num_image_tokens (List[int]): the number of image tokens
    """
    # if self.image_gen_tag in answer:
    #     answer = answer.replace(self.image_gen_tag, self.image_gen_tag*576)
        
    sft_format = question + answer

    # tokenize
    input_ids = self.tokenizer.encode(sft_format)
    input_ids = torch.LongTensor(input_ids)

    # add image tokens to the input_ids
    image_token_mask: torch.BoolTensor = input_ids == self.image_id
    image_indices = image_token_mask.nonzero()
    input_ids, num_image_tokens = self.add_image_token(
        image_indices=image_indices,
        input_ids=input_ids,
    )
    
    gen_token_mask: torch.BoolTensor = input_ids == self.image_gen_id
    gen_indices = gen_token_mask.nonzero()
    input_ids, num_image_gen_tokens = self.add_image_gen_token(
        image_indices=gen_indices,
        input_ids=input_ids,
    )

    # load images
    images_outputs = self.image_processor(images, return_tensors="pt")
    images_gen_outputs = self.gen_preprocess(gen_images)
    
    question_input_ids = self.tokenizer.encode(question)
    question_input_ids = torch.LongTensor(question_input_ids)
    
    question_image_token_mask: torch.BoolTensor = question_input_ids == self.image_id
    question_image_indices = question_image_token_mask.nonzero()
    question_input_ids, _ = self.add_image_token(
        image_indices=question_image_indices,
        input_ids=question_input_ids,
    )
    
    target_input_ids = input_ids.clone()
    # mask the question tokens (plus the following <image_start_tag>) so the LM loss only covers the answer
    target_input_ids[:len(question_input_ids)+1] = self.ignore_id
    
    target_gen_input_ids = torch.full((len(input_ids),), self.ignore_id)
    # legacy code
    assert torch.sum(input_ids == self.image_gen_id) == len(images_gen_outputs[0])
    target_gen_input_ids[input_ids == self.image_gen_id] = images_gen_outputs[0]
    
    target_input_ids[input_ids == self.image_gen_id] = self.ignore_id
    
    
    prepare = VLChatProcessorTrainOutput(
        sft_format=sft_format,
        input_ids=input_ids,
        pixel_values=images_outputs.pixel_values,
        num_image_tokens=num_image_tokens,
        num_image_gen_tokens=num_image_gen_tokens,
        target_ids=target_input_ids,
        target_gen_ids=target_gen_input_ids,
        gen_codebooks=images_gen_outputs,
    )

    return prepare

I'm simply checking whether the model can take an image fed to the understanding encoder (SigLIP) and reproduce it as codes from the generation encoder (VQ model), i.e. an image -> image task.
So far, it doesn't seem to be working well, though... 😭
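
In case it helps, this is roughly how the generation-side loss could be wired on top of that output (a sketch, not verified; the batched/collated batch with an attention_mask, the inputs_embeds from prepare_inputs_embeds, and ignore_id == -100 are assumptions, and gen_head is the generation head used in generation_inference.py):

import torch.nn.functional as F

# hidden states from the language-model backbone
hidden = vl_gpt.language_model.model(
    inputs_embeds=inputs_embeds,
    attention_mask=batch.attention_mask,
).last_hidden_state                                   # (B, seq_len, hidden)

gen_logits = vl_gpt.gen_head(hidden)                  # (B, seq_len, codebook_size)

# target_gen_ids holds VQ codes at generation-token positions and -100 elsewhere;
# shift by one so position t predicts the code at position t+1
gen_loss = F.cross_entropy(
    gen_logits[:, :-1].reshape(-1, gen_logits.size(-1)),
    batch.target_gen_ids[:, 1:].reshape(-1),
    ignore_index=-100,
)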
