What happened?
I'm working on enabling the mlperf harness for Llama3.1_405b. However, I found an OOM issue during prefill that I was able to reproduce using llama3.1_f16_8b_tp8. This was found on MI300X-3.
The mlperf dataset contains some long prompts. For example, the first prompt that it sends is 11,478 tokens. In order to have enough pages in the KVCache for this request, I raised the block_seq_stride from 32 to 48.
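For rough sizing (assuming one KVCache page covers block_seq_stride tokens): an 11,478-token prompt needs ceil(11478 / 32) = 359 pages at a stride of 32, versus ceil(11478 / 48) = 240 pages at a stride of 48, so each request draws fewer pages from the fixed pool.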
When invoking prefill for llama3.1_8b_tp8 with the async allocator, we use up all of our RAM and get the following error output:
When using --device_allocator=caching, iree-run-module is able to complete successfully. iree-run-module is also still able to run prefill with the async allocator when given a smaller prompt length.
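For reference, the passing and failing runs differ only in the allocator flag on iree-run-module (a sketch; <other flags> stands in for the rest of the invocation, shown in full under the repro steps below):

# OOMs on the 11,478-token prompt (async allocator)
iree-run-module <other flags>

# completes successfully
iree-run-module --device_allocator=caching <other flags>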
A trace of iree-run-module for the failing scenario can be found here.

Steps to reproduce your issue
1. Use the weights at /data/llama3.1/weights/8b/fp16/tp8 on MI300X-3.
2. Download the prefill inputs:

mkdir inputs_short
cd inputs_short
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/tokens.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/seq_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/seq_block_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_0.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_1.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_2.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_3.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_4.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_5.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_6.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short/cache_state_shard_7.npy
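Equivalently, as a bash loop (a compact sketch of the same downloads, relying on brace expansion):

BASE=https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/short
mkdir inputs_short && cd inputs_short
for f in tokens seq_ids seq_block_ids cache_state_shard_{0..7}; do wget "$BASE/$f.npy"; done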
3. Compile the model:

iree-compile llama3.1_8b_fp16.mlir -o llama_3.1_8b_fp16.vmfb \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  --iree-hal-target-device=hip[4] \
  --iree-hal-target-device=hip[5] \
  --iree-hal-target-device=hip[6] \
  --iree-hal-target-device=hip[7] \
  --iree-hip-target=gfx942 \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
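4. The steps above stop at compilation; the failing run is then an iree-run-module invocation along these lines (a sketch, not the exact command from the report — the function name prefill_bs4, the model=model.irpa parameter file, and the input ordering are assumptions based on the downloaded file names and typical sharktank exports):

iree-run-module \
  --module=llama_3.1_8b_fp16.vmfb \
  --parameters=model=model.irpa \
  --device=hip://0 --device=hip://1 --device=hip://2 --device=hip://3 \
  --device=hip://4 --device=hip://5 --device=hip://6 --device=hip://7 \
  --function=prefill_bs4 \
  --input=@inputs_short/tokens.npy \
  --input=@inputs_short/seq_ids.npy \
  --input=@inputs_short/seq_block_ids.npy \
  --input=@inputs_short/cache_state_shard_0.npy \
  --input=@inputs_short/cache_state_shard_1.npy \
  --input=@inputs_short/cache_state_shard_2.npy \
  --input=@inputs_short/cache_state_shard_3.npy \
  --input=@inputs_short/cache_state_shard_4.npy \
  --input=@inputs_short/cache_state_shard_5.npy \
  --input=@inputs_short/cache_state_shard_6.npy \
  --input=@inputs_short/cache_state_shard_7.npy

Adding --device_allocator=caching to this command is the variant that completes successfully.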
What component(s) does this issue relate to?
Runtime
Version information
ec6e00b
Additional context
No response