Setting the attention_mask and past_key_values simultaneously causes an error #34835

Closed · 4 tasks done
MikeDean2367 opened this issue Nov 20, 2024 · 2 comments

MikeDean2367 commented Nov 20, 2024

System Info

transformers version: 4.46.3
python version: 3.8.0
System: Ubuntu 20.04

Who can help?

@ArthurZucker @stevhliu @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForCausalLM
import torch
import numpy as np

def generate_mask(length: int, dtype=torch.bfloat16):
    # Build a lower-triangular (causal) boolean mask.
    # Note: np.bool was removed in recent NumPy; use the builtin bool.
    mask = np.tril(np.ones([length, length], dtype=bool))
    mask = torch.as_tensor(mask)
    # Convert to an additive mask: 0 where attention is allowed,
    # the dtype's minimum value where it is masked out.
    min_dtype = torch.finfo(dtype).min
    mask = mask.to(dtype)
    mask[torch.where(mask == 0)] = min_dtype
    mask[torch.where(mask == 1)] = 0
    # Add batch and head dimensions: [1, 1, length, length].
    return mask.unsqueeze(dim=0).unsqueeze(dim=1).cuda()

@torch.no_grad()
def main():
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    input_ids = [1,2,3,4,5,6]
    past_key_values = None
    model_output = model(
        input_ids=torch.as_tensor(
            [input_ids],
            device="cuda"
        ),
        return_dict=True,
        attention_mask=generate_mask(len(input_ids)),
        past_key_values=past_key_values,
        use_cache=True
    )
    past_key_values = model_output.past_key_values
    input_ids.append(7)
    # the call below raises the error
    model_output = model(
        input_ids=torch.as_tensor(
            [input_ids],
            device="cuda"
        ),
        return_dict=True,
        attention_mask=generate_mask(len(input_ids)),
        use_cache=True,
        past_key_values=past_key_values,
    )

if __name__ == '__main__':
    main()

The code above reproduces my issue: I pass attention_mask and past_key_values at the same time. Removing either one avoids the error, but my research goal requires setting both. What should I do? I have read modeling_llama.py carefully, and I suspect the code in this line is the crucial part.

Expected behavior

I want a feasible solution :)

zucchini-nlp commented Nov 21, 2024

@MikeDean2367 hey!

When you call forward with past_key_values, you should pass only the new, unprocessed input ids; in this case input_ids = [[7]]. However, note that the attention mask should be full, accounting for the already processed keys and values. The reason is that in the Attention module we concatenate the current keys with the past keys, which gives 1 + 6 tokens, so we need an attention mask of size 7. In this code snippet the attention mask should therefore be either a 2D mask of shape [bs, 7] or a 4D mask of shape [bs, heads, 1, 7].
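A minimal sketch of what the corrected second call could look like, continuing the reproduction script above (model and past_key_values come from the snippet; the specific mask constructions here are illustrative assumptions, not the only valid forms):

# Second forward pass: feed only the new token, with a mask covering all 7 positions.
new_ids = torch.as_tensor([[7]], device="cuda")

# Option A: a 2D padding-style mask of shape [bs, 7], all ones because every
# past and current token should be attended to.
attn_2d = torch.ones(1, 7, dtype=torch.long, device="cuda")

# Option B: a 4D additive mask of shape [bs, 1, q_len(=1), kv_len(=7)], 0 where
# attention is allowed; the head dimension of 1 broadcasts, as in generate_mask above.
attn_4d = torch.zeros(1, 1, 1, 7, dtype=torch.bfloat16, device="cuda")

model_output = model(
    input_ids=new_ids,
    attention_mask=attn_2d,           # or attn_4d
    past_key_values=past_key_values,  # cache from the first forward pass
    use_cache=True,
    return_dict=True,
)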

Note, however, that all of the above applies to calling forward. If you want to call generate(), the correct method is to concatenate and pass the whole input ids (past + present). Generate will take care of cropping the already processed ids and keeping only the new ones.
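For comparison, a rough sketch of the generate() path under the same setup (the generation arguments are illustrative, and passing a pre-filled past_key_values to generate() is assumed to be supported in this transformers version):

# generate() expects the full sequence (past + present); it crops internally.
full_ids = torch.as_tensor([[1, 2, 3, 4, 5, 6, 7]], device="cuda")
attn = torch.ones_like(full_ids)  # 2D mask over the whole sequence

out_ids = model.generate(
    input_ids=full_ids,
    attention_mask=attn,
    past_key_values=past_key_values,  # cache for the first 6 tokens
    max_new_tokens=10,
)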

I agree that this is a bit confusing, and we were recently asked the same question in #34232 (comment). I will check whether there is a historical reason for the cropping/not-cropping difference and try to make the input format uniform across forward and generate.

cc @gante also for when you come back

MikeDean2367 (Author) commented
@zucchini-nlp Hi, thank you for your response! I followed your suggestions and have successfully resolved the issue.
