
Llama3.1_fp16_8b_tp8 fails prefill for long prompts when using async allocator #19812

Open · stbaione opened this issue Jan 24, 2025 · 0 comments
Labels: bug 🐞 Something isn't working

What happened?

I'm working on enabling the MLPerf harness for Llama3.1_405b. However, I found an OOM issue during prefill that I was able to reproduce using llama3.1_fp16_8b_tp8 on MI300X-3.

The MLPerf dataset contains some long prompts; the first prompt it sends, for example, is 11,478 tokens. To have enough pages in the KV cache for this request, I raised block_seq_stride from 32 to 48 (a quick page-count sketch follows below).
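
For scale, a quick page-count check (my back-of-the-envelope arithmetic, assuming one KV-cache page holds block_seq_stride tokens):

# Pages needed for the 11,478-token prompt at each block_seq_stride (ceiling division).
TOKENS=11478
for STRIDE in 32 48; do
  echo "block_seq_stride=$STRIDE -> $(( (TOKENS + STRIDE - 1) / STRIDE )) pages"
done

At stride 32 the prompt needs 359 pages; at stride 48 it needs 240, which is presumably why bumping the stride lets the request fit in the page pool.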

When invoking prefill for llama3.1_8b_tp8 with the async allocator, we exhaust all available device memory:

[screenshot: device memory usage]

And get the following error output:

:1:rocdevice.cpp            :2388: 282642610955d us:  Fail allocation local memory
:1:rocdevice.cpp            :2107: 282642610965d us:  Failed creating memory
:1:memory.cpp               :358 : 282642610973d us:  Video memory allocation failed!
:1:memory.cpp               :318 : 282642610979d us:  Can't allocate memory size - 0xC0880000 bytes!
:1:rocdevice.cpp            :2443: 282642610984d us:  failed to create a svm hidden buffer!
:1:memory.cpp               :1534: 282642610990d us:  Unable to allocate aligned memory
:1:hip_memory.cpp           :329 : 282642613868d us:  Allocation failed : Device memory : required :11820072960 | free :10387193856 | total :206141652992
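
For context, plain unit conversions of the byte counts in the log (my annotation; no new data):

echo $(( 0xC0880000 ))                  # 3230138368 bytes (~3.0 GiB): the failed allocation
echo $(( 11820072960 / 1024 / 1024 ))   # 11272 MiB (~11.0 GiB) required
echo $(( 10387193856 / 1024 / 1024 ))   # 9906 MiB (~9.7 GiB) free
echo $(( 206141652992 / 1024 / 1024 ))  # 196592 MiB (~192 GiB) total per MI300X GPU

So the final request overshoots free device memory by roughly 1.3 GiB on a 192 GiB device.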

When using --device_allocator=caching (the only flag changed from the failing invocation), iree-run-module completes successfully:

[screenshot: successful iree-run-module output]

With a shorter prompt, iree-run-module is still able to run prefill using the async allocator.

A trace of iree-run-module for the failing scenario can be found here.

Steps to reproduce your issue

  1. IRPA files for 8b_tp8 can be found at /data/llama3.1/weights/8b/fp16/tp8 on MI300X-3
  2. Download the MLIR file here
  3. Use this script to download the long inputs
#!/bin/bash

mkdir inputs_long
cd inputs_long

wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/tokens.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/seq_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/seq_block_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_0.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_1.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_2.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_3.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_4.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_5.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_6.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/long/cache_state_shard_7.npy
  4. Use this script to download the short inputs
#!/bin/bash

mkdir inputs_short
cd inputs_short

wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/tokens.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/seq_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/seq_block_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_0.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_1.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_2.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_3.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_4.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_5.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_6.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_7.npy
  5. Compile the MLIR to a vmfb
iree-compile llama3.1_8b_fp16.mlir -o llama_3.1_8b_fp16.vmfb \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  --iree-hal-target-device=hip[4] \
  --iree-hal-target-device=hip[5] \
  --iree-hal-target-device=hip[6] \
  --iree-hal-target-device=hip[7] \
  --iree-hip-target=gfx942 \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
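
As an optional sanity check before running (my addition, not part of the original repro; assumes the iree-dump-module tool that ships with IREE is on your PATH):

# Confirm the compiled module exports the prefill entry point used below.
iree-dump-module llama_3.1_8b_fp16.vmfb | grep prefill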
  6. Run with the async allocator (should see the OOM error)
ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 iree-run-module \
  --hip_use_streams=true \
  --module=llama_3.1_8b_fp16.vmfb \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank0.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank1.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank2.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank3.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank4.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank5.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank6.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank7.irpa \
  --device=hip://0 --device=hip://1 --device=hip://2 --device=hip://3 \
  --device=hip://4 --device=hip://5 --device=hip://6 --device=hip://7 \
  --function=prefill_bs4 \
  --input=@inputs_long/tokens.npy \
  --input=@inputs_long/seq_ids.npy \
  --input=@inputs_long/seq_block_ids.npy \
  --input=@inputs_long/cache_state_shard_0.npy \
  --input=@inputs_long/cache_state_shard_1.npy \
  --input=@inputs_long/cache_state_shard_2.npy \
  --input=@inputs_long/cache_state_shard_3.npy \
  --input=@inputs_long/cache_state_shard_4.npy \
  --input=@inputs_long/cache_state_shard_5.npy \
  --input=@inputs_long/cache_state_shard_6.npy \
  --input=@inputs_long/cache_state_shard_7.npy
  7. Run with the caching allocator (should run successfully)
ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 iree-run-module \
  --hip_use_streams=true \
  --device_allocator=caching \
  --module=llama_3.1_8b_fp16.vmfb \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank0.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank1.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank2.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank3.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank4.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank5.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank6.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank7.irpa \
  --device=hip://0 --device=hip://1 --device=hip://2 --device=hip://3 \
  --device=hip://4 --device=hip://5 --device=hip://6 --device=hip://7 \
  --function=prefill_bs4 \
  --input=@inputs_long/tokens.npy \
  --input=@inputs_long/seq_ids.npy \
  --input=@inputs_long/seq_block_ids.npy \
  --input=@inputs_long/cache_state_shard_0.npy \
  --input=@inputs_long/cache_state_shard_1.npy \
  --input=@inputs_long/cache_state_shard_2.npy \
  --input=@inputs_long/cache_state_shard_3.npy \
  --input=@inputs_long/cache_state_shard_4.npy \
  --input=@inputs_long/cache_state_shard_5.npy \
  --input=@inputs_long/cache_state_shard_6.npy \
  --input=@inputs_long/cache_state_shard_7.npy
  8. Run the shorter input prompt with the async allocator (should run successfully)
ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 iree-run-module \
  --hip_use_streams=true \
  --module=llama_3.1_8b_fp16.vmfb \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank0.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank1.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank2.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank3.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank4.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank5.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank6.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank7.irpa \
  --device=hip://0 --device=hip://1 --device=hip://2 --device=hip://3 \
  --device=hip://4 --device=hip://5 --device=hip://6 --device=hip://7 \
  --function=prefill_bs4 \
  --input=@inputs_short/tokens.npy \
  --input=@inputs_short/seq_ids.npy \
  --input=@inputs_short/seq_block_ids.npy \
  --input=@inputs_short/cache_state_shard_0.npy \
  --input=@inputs_short/cache_state_shard_1.npy \
  --input=@inputs_short/cache_state_shard_2.npy \
  --input=@inputs_short/cache_state_shard_3.npy \
  --input=@inputs_short/cache_state_shard_4.npy \
  --input=@inputs_short/cache_state_shard_5.npy \
  --input=@inputs_short/cache_state_shard_6.npy \
  --input=@inputs_short/cache_state_shard_7.npy
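
Optionally (my addition; assumes rocm-smi from the ROCm tools), watch device memory climb while the long-prompt run is in flight:

# Poll per-device VRAM usage once a second in a second terminal.
watch -n 1 rocm-smi --showmeminfo vram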

What component(s) does this issue relate to?

Runtime

Version information

ec6e00b

Additional context

No response
