
Prone to OOM when performing multi-frame inference #88

Open
voidchant opened this issue Dec 13, 2024 · 4 comments

Comments

voidchant commented Dec 13, 2024

I'm executing the code from Hugging Face with an 8xA100 (40G) GPU configuration, but I can only process up to ~4 frames before it OOMs.

voidchant (Author) commented Dec 13, 2024

import os

import torch
from tqdm import tqdm
from transformers import AriaForConditionalGeneration, AriaProcessor

# load_video and get_placeholders_for_videos are helper functions from the Aria
# video-inference example; model_id_or_path, prefix, and accepted_files are
# defined elsewhere in my script.

model = AriaForConditionalGeneration.from_pretrained(
    model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
)
processor = AriaProcessor.from_pretrained(model_id_or_path)

for video_file in tqdm(os.listdir(prefix)):

    if video_file not in accepted_files:
        continue

    video_path = os.path.join(prefix, video_file)
    frames, frame_timestamps = load_video(video_path, num_frames=4)

    contents = get_placeholders_for_videos(frames, frame_timestamps)
    messages = [
        {
            "role": "user",
            "content": [
                *contents,
                {"text": "You are the dashboard camera of a moving car. Your goal is to capture the detailed movement of the ego car, the movements of objects in view, and essential details and abnormalities on the road.", "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=frames, return_tensors="pt", max_image_size=490)
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            max_new_tokens=2048,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=False,
        )
        output_ids = output[0][inputs["input_ids"].shape[1]:]
        result = processor.decode(output_ids, skip_special_tokens=True)

xffxff (Collaborator) commented Dec 13, 2024

Have you installed FlashAttention?

You can also add a print(model) statement after loading the model to display the attention mechanism used by the ViT model and the language model.
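A minimal sketch of checking this without reading the whole dump (assuming a recent transformers version, where the config exposes an _attn_implementation attribute; the name can vary between versions):

# Sketch: report the configured attention implementation,
# e.g. "eager", "sdpa", or "flash_attention_2".
print(getattr(model.config, "_attn_implementation", None))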

voidchant (Author) commented

Thanks for your reply!
Yes, I have installed FlashAttention:

flash-attn                2.7.2.post1

After adding print(model), the output is:

AriaForConditionalGeneration(
  (vision_tower): Idefics3VisionTransformer(
    (embeddings): Idefics3VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
      (position_embedding): Embedding(4900, 1152)
    )
    (encoder): Idefics3Encoder(
      (layers): ModuleList(
        (0-26): 27 x Idefics3EncoderLayer(
          (self_attn): Idefics3VisionAttention(
            (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
            (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
          )
          (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          (mlp): Idefics3VisionMLP(
            (activation_fn): PytorchGELUTanh()
            (fc1): Linear(in_features=1152, out_features=4304, bias=True)
            (fc2): Linear(in_features=4304, out_features=1152, bias=True)
          )
          (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
  )
  (multi_modal_projector): AriaProjector(
    (cross_attn): AriaCrossAttention(
      (q_proj): Linear(in_features=1152, out_features=1152, bias=False)
      (k_proj): Linear(in_features=1152, out_features=1152, bias=False)
      (v_proj): Linear(in_features=1152, out_features=1152, bias=False)
      (multihead_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1152, out_features=1152, bias=True)
      )
      (linear): Linear(in_features=1152, out_features=1152, bias=True)
      (dropout): Dropout(p=0, inplace=False)
      (layer_norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
      (layer_norm_kv): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    )
    (layer_norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
    (feed_forward): AriaProjectorMLP(
      (linear_in): Linear(in_features=1152, out_features=2560, bias=False)
      (linear_out): Linear(in_features=2560, out_features=2560, bias=False)
      (act): NewGELUActivation()
    )
  )
  (language_model): AriaTextForCausalLM(
    (model): AriaTextModel(
      (embed_tokens): Embedding(100352, 2560, padding_idx=2)
      (layers): ModuleList(
        (0-27): 28 x AriaTextDecoderLayer(
          (self_attn): AriaTextSdpaAttention(
            (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
            (k_proj): Linear(in_features=2560, out_features=2560, bias=False)
            (v_proj): Linear(in_features=2560, out_features=2560, bias=False)
            (o_proj): Linear(in_features=2560, out_features=2560, bias=False)
          )
          (mlp): AriaTextMoELayer(
            (router): Linear(in_features=2560, out_features=64, bias=False)
            (experts): AriaGroupedExpertsMLP(
              (fc1): AriaGroupedExpertsGemm()
              (fc2): AriaGroupedExpertsGemm()
            )
            (shared_experts): AriaSharedExpertsMLP(
              (gate_proj): Linear(in_features=2560, out_features=3328, bias=False)
              (up_proj): Linear(in_features=2560, out_features=3328, bias=False)
              (down_proj): Linear(in_features=3328, out_features=2560, bias=False)
              (act_fn): SiLU()
            )
          )
          (input_layernorm): AriaTextRMSNorm((2560,), eps=1e-06)
          (post_attention_layernorm): AriaTextRMSNorm((2560,), eps=1e-06)
        )
      )
      (norm): AriaTextRMSNorm((2560,), eps=1e-06)
      (rotary_emb): AriaTextRotaryEmbedding()
    )
    (lm_head): Linear(in_features=2560, out_features=100352, bias=False)
  )
)


Does this suggest FlashAttention is activated or not?

xffxff (Collaborator) commented Dec 13, 2024

@voidchant Based on the output, it seems that the ViT model isn't using flash attention.

Could you try the following code instead?

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id_or_path = "rhymes-ai/Aria"
revision = "4844f0b5ff678e768236889df5accbe4967ec845"

model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    revision=revision,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)

processor = AutoProcessor.from_pretrained(
    model_id_or_path,
    revision=revision,
    trust_remote_code=True,
)

This code uses a slightly older version of the Aria model. The official Transformers repo has recently added support for the Aria model, but it doesn’t yet support flash attention. As a workaround, we can roll back to this older version for now.
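To confirm the rollback actually enables flash attention, one quick sanity check is to list the attention classes in use (a sketch; after loading with attn_implementation="flash_attention_2" you should no longer see only the Idefics3VisionAttention / AriaTextSdpaAttention classes from the earlier printout):

# Sketch: collect the distinct attention-module class names across the model.
attn_classes = {type(m).__name__ for m in model.modules() if "Attention" in type(m).__name__}
print(attn_classes)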
