
[Core][Feature] Input metadata dump on crash #13407

Open · wants to merge 3 commits into main

Conversation

@wallashss (Contributor) commented Feb 17, 2025

This PR adds a feature to dump input metadata when the vLLM engine crashes. In essence, this change is the spiritual successor to #8305, which was recently removed in #12582. However, I took a different approach, since this feature can give us more hints for debugging crashes in a production environment. So I would like to propose it again to the community and give it a second chance.

Summary:

  • The dump is only logged (instead of pickled as in [MISC] Dump model runner inputs when crashing #8305)
  • Implemented for both the V0 and V1 engines
  • Only tensor metadata is dumped, so the dump still works on CUDA crashes and tensor contents stay obfuscated to avoid leaking sensitive data (see the sketch after this list)
  • Introduced custom exceptions, which might be useful for other kinds of custom error handling
  • Some fields, such as the prompt and the prompt token ids, are removed to avoid logging sensitive data in a production environment
  • Dump system stats, to capture the system status at the moment of the crash. TODO: for V1
  • Print the engine config again, so the setup is still available in truncated logs
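
For context on the obfuscation point above, here is a minimal sketch of the kind of metadata-only serialization that produces the "Tensor(shape=..., device=..., dtype=...)" strings in the dumps below. The helper name and the recursion over dataclasses are illustrative assumptions, not the PR's actual implementation:

```python
from dataclasses import fields, is_dataclass

import torch


def summarize(value):
    """Recursively replace tensors with a metadata-only string.

    Only shape/device/dtype are kept, so token ids and activations never
    reach the logs and no device copy is needed after a CUDA crash.
    """
    if isinstance(value, torch.Tensor):
        return (f"Tensor(shape={value.shape}, "
                f"device={value.device}, dtype={value.dtype})")
    if is_dataclass(value) and not isinstance(value, type):
        return {"class": type(value).__name__,
                **{f.name: summarize(getattr(value, f.name))
                   for f in fields(value)}}
    if isinstance(value, dict):
        return {str(k): summarize(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [summarize(v) for v in value]
    return value


# Usage sketch: json.dumps(summarize(model_input)) is roughly what would
# end up in the "Model input for execution as JSON" log line.
```
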
V0 dump sample
ERROR 02-17 13:19:43 error_report.py:89] Engine crashed, dumping input data
ERROR 02-17 13:19:43 error_report.py:98] V0 LLM engine (v0.6.6.dev221+gcfd3219f5.d20250114) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}, use_cached_outputs=False, 
ERROR 02-17 13:19:43 error_report.py:110] Model input for execution as JSON:
ERROR 02-17 13:19:43 error_report.py:111] {"class": "ModelInputForGPUWithSamplingMetadata", "input_tokens": "Tensor(shape=torch.Size([26]), device=cuda:0,dtype=torch.int64)", "input_positions": "Tensor(shape=torch.Size([26]), device=cuda:0,dtype=torch.int64)", "token_types": null, "seq_lens": [6, 8, 6, 6], "query_lens": [6, 8, 6, 6], "lora_mapping": null, "lora_requests": [], "attn_metadata": {"class": "FlashAttentionMetadata", "num_prefills": 4, "num_prefill_tokens": 26, "num_decode_tokens": 0, "slot_mapping": "Tensor(shape=torch.Size([26]), device=cuda:0,dtype=torch.int64)", "multi_modal_placeholder_index_maps": {}, "enable_kv_scales_calculation": true, "seq_lens": [6, 8, 6, 6], "seq_lens_tensor": "Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int32)", "max_prefill_seq_len": 8, "max_decode_seq_len": 0, "context_lens_tensor": "Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int32)", "block_tables": "Tensor(shape=torch.Size([4, 0]), device=cuda:0,dtype=torch.int32)", "use_cuda_graph": false, "max_query_len": 8, "max_decode_query_len": 1, "query_start_loc": "Tensor(shape=torch.Size([5]), device=cuda:0,dtype=torch.int32)", "seq_start_loc": "Tensor(shape=torch.Size([5]), device=cuda:0,dtype=torch.int32)", "_cached_prefill_metadata": null, "_cached_decode_metadata": null, "encoder_seq_lens": null, "encoder_seq_lens_tensor": null, "encoder_seq_start_loc": null, "max_encoder_seq_len": null, "num_encoder_tokens": null, "cross_slot_mapping": null, "cross_block_tables": null}, "prompt_adapter_mapping": null, "prompt_adapter_requests": [], "multi_modal_kwargs": {}, "request_ids_to_seq_ids": {"0": [0], "1": [1], "2": [2], "3": [3]}, "finished_requests_ids": [], "virtual_engine": 0, "async_callback": null, "scheduler_outputs": null, "sampling_metadata": {"class": "SamplingMetadata", "seq_groups": [{"class": "SequenceGroupToSample", "seq_ids": [0], "sampling_params": {"class": "SamplingParams", "sampling_type": "<SamplingType.GREEDY: 0>"}, "seq_data": {"0": {"class": "SequenceData", "prompt_token_ids_len": 6, "output_token_ids_len": 0, "cumulative_logprob": 0.0, "get_num_computed_tokens": 0}}, "seq_len": 6, "query_len": 6, "generator": null, "is_prompt": true, "prompt_logprob_indices": [], "sample_indices": [0]}, {"class": "SequenceGroupToSample", "seq_ids": [1], "sampling_params": {"class": "SamplingParams", "sampling_type": "<SamplingType.GREEDY: 0>"}, "seq_data": {"1": {"class": "SequenceData", "prompt_token_ids_len": 8, "output_token_ids_len": 0, "cumulative_logprob": 0.0, "get_num_computed_tokens": 0}}, "seq_len": 8, "query_len": 8, "generator": null, "is_prompt": true, "prompt_logprob_indices": [], "sample_indices": [1]}, {"class": "SequenceGroupToSample", "seq_ids": [2], "sampling_params": {"class": "SamplingParams", "sampling_type": "<SamplingType.GREEDY: 0>"}, "seq_data": {"2": {"class": "SequenceData", "prompt_token_ids_len": 6, "output_token_ids_len": 0, "cumulative_logprob": 0.0, "get_num_computed_tokens": 0}}, "seq_len": 6, "query_len": 6, "generator": null, "is_prompt": true, "prompt_logprob_indices": [], "sample_indices": [2]}, {"class": "SequenceGroupToSample", "seq_ids": [3], "sampling_params": {"class": "SamplingParams", "sampling_type": "<SamplingType.GREEDY: 0>"}, "seq_data": {"3": {"class": "SequenceData", "prompt_token_ids_len": 6, "output_token_ids_len": 0, "cumulative_logprob": 0.0, "get_num_computed_tokens": 0}}, "seq_len": 6, "query_len": 6, "generator": null, "is_prompt": true, "prompt_logprob_indices": [], "sample_indices": [3]}], "selected_token_indices": 
"Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int64)", "categorized_sample_indices": {"0": "Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int32)", "1": "Tensor(shape=torch.Size([0]), device=cuda:0,dtype=torch.int32)", "2": "Tensor(shape=torch.Size([0]), device=cuda:0,dtype=torch.int32)"}, "num_prompts": 4, "skip_sampler_cpu_output": false, "reuse_sampling_tensors": false}, "is_prompt": true}
ERROR 02-17 13:19:43 error_report.py:126] Batch info: requests_count=4, requests_prompt_token_ids_lenghts=(6, 8, 6, 6), requests_ids=(0, 1, 2, 3)
ERROR 02-17 13:19:43 error_report.py:133] Errored Batch request #0: request_id=0 prompt_token_ids_lengths=6, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 02-17 13:19:43 error_report.py:133] Errored Batch request #1: request_id=1 prompt_token_ids_lengths=8, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 02-17 13:19:43 error_report.py:133] Errored Batch request #2: request_id=2 prompt_token_ids_lengths=6, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 02-17 13:19:43 error_report.py:133] Errored Batch request #3: request_id=3 prompt_token_ids_lengths=6, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 02-17 13:19:43 error_report.py:143] System stats:
ERROR 02-17 13:19:43 error_report.py:144] Stats(now=1739798383.8002088, num_running_sys=4, num_waiting_sys=0, num_swapped_sys=0, gpu_cache_usage_sys=3.115750116844396e-05, cpu_cache_usage_sys=0.0, cpu_prefix_cache_hit_rate=-1, gpu_prefix_cache_hit_rate=-1, num_prompt_tokens_iter=26, num_generation_tokens_iter=0, num_tokens_iter=26, time_to_first_tokens_iter=[], time_per_output_tokens_iter=[], num_preemption_iter=0, time_e2e_requests=[], time_queue_requests=[], time_inference_requests=[], time_prefill_requests=[], time_decode_requests=[], time_in_queue_requests=[], model_forward_time_requests=[], model_execute_time_requests=[], num_prompt_tokens_requests=[], num_generation_tokens_requests=[], n_requests=[], max_num_generation_tokens_requests=[], max_tokens_requests=[], finished_reason_requests=[], waiting_lora_adapters=[], running_lora_adapters=[], max_lora='0', spec_decode_metrics=None)
V1 Dump Sample
ERROR 02-17 13:19:35 error_report.py:89] Engine crashed, dumping input data
ERROR 02-17 13:19:35 error_report.py:92] V1 LLM engine (v0.6.6.dev221+gcfd3219f5.d20250114) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}, 
ERROR 02-17 13:19:35 error_report.py:151] Scheduler output for model execution as JSON:
ERROR 02-17 13:19:35 error_report.py:153] {"class": "SchedulerOutput", "scheduled_new_reqs": [{"class": "NewRequestData", "req_id": "0", "prompt_token_ids_len": 6, "prompt": "", "mm_inputs": [], "mm_hashes": [], "mm_positions": [], "sampling_params": {"class": "SamplingParams"}, "block_ids": [0, 1, 2, 3, 4], "num_computed_tokens": 0, "lora_request": null}, {"class": "NewRequestData", "req_id": "1", "prompt_token_ids_len": 8, "prompt": "", "mm_inputs": [], "mm_hashes": [], "mm_positions": [], "sampling_params": {"class": "SamplingParams"}, "block_ids": [5, 6, 7, 8, 9], "num_computed_tokens": 0, "lora_request": null}, {"class": "NewRequestData", "req_id": "2", "prompt_token_ids_len": 6, "prompt": "", "mm_inputs": [], "mm_hashes": [], "mm_positions": [], "sampling_params": {"class": "SamplingParams"}, "block_ids": [10, 11, 12, 13, 14], "num_computed_tokens": 0, "lora_request": null}, {"class": "NewRequestData", "req_id": "3", "prompt_token_ids_len": 6, "prompt": "", "mm_inputs": [], "mm_hashes": [], "mm_positions": [], "sampling_params": {"class": "SamplingParams"}, "block_ids": [15, 16, 17, 18, 19], "num_computed_tokens": 0, "lora_request": null}], "scheduled_cached_reqs": [], "num_scheduled_tokens": {"0": 6, "1": 8, "2": 6, "3": 6}, "total_num_scheduled_tokens": 26, "scheduled_encoder_inputs": {}, "num_common_prefix_blocks": 0, "finished_req_ids": [], "free_encoder_input_ids": []}


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 label Feb 17, 2025
Signed-off-by: Wallas Santos <[email protected]>
@joerunde added the ready label Feb 18, 2025
@tjohnson31415 (Contributor) commented:

@wallashss Thanks for writing up this PR. I think it will be useful to have debugging details printed to the logs on a crash!

When I try out these changes in my dev environment, running online mode with

vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --max-num-seqs 4 --enforce-eager --max-model-len 8192

and sending a request with a large prompt and requesting prompt_logprobs to trigger an OOM:

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "model",
        "prompt": "'"$(seq -s ' ' 1 1500)"'",
        "max_tokens": 100,
        "prompt_logprobs": 10
    }'

I see the ModelExecutionError being raised, but then the server seems to hang, never dumping the debug info or exiting... The logs in this case look like:

ERROR 02-20 20:47:16 engine.py:139] vllm.worker.worker_base.ModelExecutionError: Model execution failure,reason: OutOfMemoryError('CUDA out of memory. Tried to allocate 3.35 GiB. GPU 0 has a total capacity of 79.14 GiB of which 1.11 GiB is free. Process 3865663 has 78.01 GiB memory in use. Of the allocated memory 77.27 GiB is allocated by PyTorch, and 246.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)')
[rank0]:[W220 20:47:17.410788625 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
   <<and after a few seconds of hanging>>
ERROR 02-20 20:47:25 client.py:300] RuntimeError('Engine process (pid 4107) died.')
ERROR 02-20 20:47:25 client.py:300] NoneType: None

The above seems to happen only when the first request to the server crashes it. If I send a shortened request first (e.g., a prompt from seq -s ' ' 1 100), then it does actually crash on the second request, but with an exception in the error reporter:

ERROR 02-20 21:05:09 engine.py:139] During handling of the above exception, another exception occurred:
ERROR 02-20 21:05:09 engine.py:139] 
ERROR 02-20 21:05:09 engine.py:139] Traceback (most recent call last):
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 137, in start
ERROR 02-20 21:05:09 engine.py:139]     self.run_engine_loop()
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 200, in run_engine_loop
ERROR 02-20 21:05:09 engine.py:139]     request_outputs = self.engine_step()
ERROR 02-20 21:05:09 engine.py:139]                       ^^^^^^^^^^^^^^^^^^
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 218, in engine_step
ERROR 02-20 21:05:09 engine.py:139]     raise e
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 209, in engine_step
ERROR 02-20 21:05:09 engine.py:139]     return self.engine.step()
ERROR 02-20 21:05:09 engine.py:139]            ^^^^^^^^^^^^^^^^^^
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/llm_engine.py", line 1393, in step
ERROR 02-20 21:05:09 engine.py:139]     dump_engine_exception(
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/error_report.py", line 124, in dump_engine_exception
ERROR 02-20 21:05:09 engine.py:139]     str(len(r.seq_data[idx].prompt_token_ids))
ERROR 02-20 21:05:09 engine.py:139]             ~~~~~~~~~~^^^^^
ERROR 02-20 21:05:09 engine.py:139] KeyError: 0
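
The KeyError suggests that dump_engine_exception is indexing seq_data with the batch position rather than the actual sequence id (which keeps growing across requests, so only the very first batch happens to use keys 0..N). A defensive lookup along these lines would avoid it (just a sketch with a hypothetical helper, since I haven't looked closely at the surrounding code in vllm/error_report.py):

```python
from typing import Any, Dict, List


def prompt_token_lengths(seq_data: Dict[int, Any]) -> List[int]:
    """Collect prompt lengths without assuming the dict keys.

    In V0, seq_data maps global sequence ids to SequenceData, so indexing
    it with an enumerate() index raises KeyError once earlier requests
    have already consumed ids 0..N.
    """
    return [len(seq.prompt_token_ids) for seq in seq_data.values()]
```
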
