
Confusion About Learnable Prompts in Prompt Learning #17

Open
exoticism4869 opened this issue Feb 14, 2025 · 2 comments

@exoticism4869

📌 Issue: Confusion About Learnable Prompts in Prompt Learning

Description:
While working with the PromptLearner class for text-based prompt learning, I ran into confusion about how the learnable prompts are inserted into the embeddings.

From the code, it appears that the learnable prompt (self.ctx) is inserted into the token embeddings by directly replacing the 16 token positions that follow the start token. However, I expected the learnable prompt to replace the initial descriptive tokens in the text, such as "A whole slide image of" in the example description "A whole slide image of lung adenocarcinoma...".

The issue is that this descriptive part, when tokenized, does not have a fixed length of 16 tokens, which raises the questions below.

📖 Code Reference:

class PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model):
        ...
        self.ctx = nn.Parameter(ctx_vectors)  # learnable context vectors, shape (n_ctx, ctx_dim)
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [name for name in classnames]  # class descriptions only; no placeholders prepended
        print('prompts:', prompts)

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])           # SOS token only
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # drops positions 1 .. n_ctx

    def forward(self):
        # abbreviated: prefix = self.token_prefix, ctx = self.ctx (expanded per class), suffix = self.token_suffix
        if self.class_token_position == "end":
            prompts = torch.cat([
                prefix,  # 1 token for the start (SOS)
                ctx,     # 16 learnable tokens
                suffix,  # remaining tokens from the class description
            ], dim=1)

🤔 My Confusion:

Why does the code insert the learnable prompts directly into the 16 token positions after the start token, regardless of what originally occupied them?

💡 Expected Behavior:

I expected the learnable prompts to replace the initial descriptive tokens like "A whole slide image of", but given that tokenizing this phrase doesn't yield exactly 16 tokens, the current implementation seems counterintuitive.
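For example, the mismatch is easy to check with CLIP's BPE tokenizer (a quick illustrative check using SimpleTokenizer from the OpenAI CLIP package, not code from this repository):

from clip.simple_tokenizer import SimpleTokenizer

_tokenizer = SimpleTokenizer()
# The descriptive phrase encodes to only a handful of BPE tokens, well short of 16:
print(len(_tokenizer.encode("A whole slide image of")))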

🙏 Additional Context:

If this behavior is intentional, I would appreciate any clarification on the underlying design choice for the fixed token length.

Thank you for your time and assistance!

@Jiangbo-Shi
Owner

Thank you for your interest in our work. The learnable prompts (ctx) do not replace the descriptive tokens like "A whole slide image of". They are instead inserted between the prefix and suffix, serving as additional learnable embeddings to enrich the prompt's representation. The fixed length of ctx ensures uniformity and avoids issues caused by variable tokenization.
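To illustrate the point with hypothetical sizes (a minimal sketch, not code from the repository; n_cls, n_ctx, ctx_dim, and seq_len are made-up values):

import torch

n_cls, n_ctx, ctx_dim, seq_len = 3, 16, 512, 77  # hypothetical CLIP-style sizes

prefix = torch.randn(n_cls, 1, ctx_dim)                               # SOS embedding per class
ctx = torch.randn(n_ctx, ctx_dim).unsqueeze(0).expand(n_cls, -1, -1)  # shared learnable context
suffix = torch.randn(n_cls, seq_len - 1 - n_ctx, ctx_dim)             # class-name/EOS/padding embeddings

prompts = torch.cat([prefix, ctx, suffix], dim=1)
print(prompts.shape)  # torch.Size([3, 77, 512]): the same fixed length for every class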

@exoticism4869
Author

Thanks for your quick response! I appreciate your explanation. However, I’ve noticed a potential issue:

I reviewed the CoOp code, where I found a clear explanation of the approach:

[Image: screenshot of the CoOp prompt-construction code]

In CoOp, they explicitly prepend n_ctx placeholder tokens ("X") to each class name, so the learnable vectors later overwrite only those placeholders. However, in your implementation, I didn't see a similar step. Instead, part of the class description's own embedding is directly replaced with the learnable prompts, which seems counterintuitive to me (see the sketch below).
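Roughly, the relevant CoOp construction is the following (paraphrased from the public CoOp code; assumes clip, clip_model, n_ctx, and classnames as in the snippet above):

prompt_prefix = " ".join(["X"] * n_ctx)  # n_ctx placeholder tokens
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
    embedding = clip_model.token_embedding(tokenized_prompts)
# Positions 1 .. n_ctx of `embedding` now hold the "X" placeholders, so overwriting
# them with the learnable ctx leaves the class-name tokens in the suffix untouched.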
