[Core][Feature] Input metadata dump on crash #13407

Open · wants to merge 3 commits into base: main
49 changes: 49 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
@@ -5,11 +5,16 @@
"""
import os
import weakref
from unittest.mock import Mock

import pytest

from vllm import LLM
from vllm.platforms import current_platform
from vllm.v1.engine.core import ModelExecutionV1Error
from vllm.v1.engine.core_client import EngineCoreClient, InprocClient
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from vllm.worker.worker_base import ModelExecutionError

from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
@@ -147,3 +152,47 @@ def test_models_distributed(
name_0="hf",
name_1="vllm",
)


def test_failed_model_execution(vllm_runner) -> None:

def make_client(
multiprocess_mode: bool,
asyncio_mode: bool,
vllm_config, # "VllmConfig"
executor_class, # "Type[Executor]"
log_stats: bool,
) -> "EngineCoreClient":
return InprocClient(vllm_config, executor_class, log_stats)

EngineCoreClient.make_client = Mock(side_effect=make_client)
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:

engine = vllm_model.model.llm_engine
mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error"))

if isinstance(engine, LLMEngineV1):
is_v1 = True
engine.engine_core.engine_core.model_executor.execute_model = \
mocked_execute_model
else: # V0
is_v1 = False
engine.model_executor.driver_worker.model_runner.execute_model = \
mocked_execute_model

with pytest.raises(RuntimeError) as exc_info:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
if is_v1:
assert isinstance(exc_info.value, ModelExecutionV1Error)
assert exc_info.value.scheduler_output is not None
else:
assert isinstance(exc_info.value, ModelExecutionError)
assert exc_info.value.model_input is not None
assert "Mocked Critical Error" in str(exc_info.value)
19 changes: 17 additions & 2 deletions vllm/engine/llm_engine.py
@@ -29,6 +29,7 @@
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.error_report import dump_engine_exception
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
@@ -1383,8 +1384,22 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]

outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
try:
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)

except BaseException as err:
stats = self._get_stats(scheduler_outputs=scheduler_outputs)
dump_engine_exception(
err=err,
config=self.vllm_config,
use_cached_outputs=self.use_cached_outputs,
engine_version=0,
stats=stats,
execute_model_req=execute_model_req,
)
# Re-raise exception
raise err

# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
157 changes: 157 additions & 0 deletions vllm/error_report.py
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import json
from typing import Any, Dict, Union

import torch

from vllm.config import VllmConfig
from vllm.engine.metrics import Stats
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SequenceData
from vllm.v1.core.scheduler_output import NewRequestData
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.worker_base import ModelExecutionError

logger = init_logger(__name__)


# Hacky way to make sure we can serialize the object in JSON format
def is_json_serializable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False


def prepare_object_to_dump(obj):
if isinstance(obj, dict):
return {k: prepare_object_to_dump(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [prepare_object_to_dump(v) for v in obj]
elif isinstance(obj, set):
return [prepare_object_to_dump(v) for v in list(obj)]
elif isinstance(obj, tuple):
return [prepare_object_to_dump(v) for v in obj]
elif isinstance(obj, enum.Enum):
return repr(obj)
elif isinstance(obj, SequenceData):
# Custom representation (based on SequenceData.__repr__)
# to obfuscate some parameters
return {
"class": "SequenceData",
"prompt_token_ids_len": len(obj._prompt_token_ids),
"output_token_ids_len": len(obj.output_token_ids),
"cumulative_logprob": obj.cumulative_logprob,
"get_num_computed_tokens": obj.get_num_computed_tokens()
}

elif isinstance(obj, NewRequestData):
obj_dict: Dict[str, Any] = {'class': type(obj).__name__}
for k, v in obj.__dict__.items():
if k == 'prompt_token_ids':
obj_dict['prompt_token_ids_len'] = len(v)
elif k == 'prompt':
obj_dict['prompt'] = ""
else:
obj_dict[k] = prepare_object_to_dump(v)

return obj_dict
elif isinstance(obj, torch.Tensor):
# Print only a summary of the tensor (not its values) to avoid exposing
# sensitive data, while still capturing metadata that is useful when
# debugging e.g. a CUDA illegal memory access
return (f"Tensor(shape={obj.shape}, "
f"device={obj.device}, "
f"dtype={obj.dtype})")
elif hasattr(obj, '__dict__'):
obj_dict = {'class': type(obj).__name__}
obj_dict.update(obj.__dict__)
return prepare_object_to_dump(obj_dict)
else:
# Return the object as-is if it is JSON-serializable;
# otherwise fall back to repr() to avoid a serialization error
if is_json_serializable(obj):
return obj
else:
return repr(obj)


def dump_engine_exception(err: BaseException,
config: VllmConfig,
engine_version: int,
stats: Union[Stats, None] = None,
use_cached_outputs: Union[bool, None] = None,
execute_model_req: Union[ExecuteModelRequest,
None] = None):

assert engine_version in (0, 1)

logger.error("Engine crashed, dumping input data")

if engine_version == 1:
logger.error(
"V1 LLM engine (v%s) with config: %s, ",
VLLM_VERSION,
config,
)
else:
logger.error(
"V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
VLLM_VERSION,
config,
use_cached_outputs,
)

# For V0
if isinstance(err, ModelExecutionError):
try:
err_json = prepare_object_to_dump(err.model_input)
logger.error("Model input for execution as JSON:")
logger.error(json.dumps(err_json))
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))

# Even if we do not have a ModelExecutionError, we can still
# get information from the batch
if execute_model_req is not None:
batch = execute_model_req.seq_group_metadata_list
requests_count = len(batch)
requests_prompt_token_ids_lengths = ', '.join([
str(len(r.seq_data[idx].prompt_token_ids))
for idx, r in enumerate(batch)
])
requests_ids = ', '.join([str(r.request_id) for r in batch])
logger.error(
"Batch info: requests_count=%s, "
"requests_prompt_token_ids_lengths=(%s), "
"requests_ids=(%s)", requests_count,
requests_prompt_token_ids_lengths, requests_ids)

for idx, r in enumerate(batch):
logger.error(
"Errored Batch request #%s: request_id=%s "
"prompt_token_ids_lengths=%s, "
"params=%s, "
"lora_request=%s, prompt_adapter_request=%s ", idx,
r.request_id, str(len(r.seq_data[idx].prompt_token_ids)),
r.sampling_params, r.lora_request, r.prompt_adapter_request)

# TODO: Have stats for V1
if stats is not None:
logger.error("System stats:")
logger.error(stats)

if engine_version == 1:
from vllm.v1.engine.core import ModelExecutionV1Error
if isinstance(err, ModelExecutionV1Error):
try:
err_json = prepare_object_to_dump(err.scheduler_output)
logger.error("Scheduler output for model execution as JSON:")
logger.error(json.dumps(err_json))
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))
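
For reference, a minimal sketch (not part of the diff; it assumes a vLLM build with this PR's vllm/error_report.py applied) of how prepare_object_to_dump normalizes nested containers and enums into something json.dumps can handle:

import enum
import json

from vllm.error_report import prepare_object_to_dump


class Phase(enum.Enum):
    PREFILL = 1
    DECODE = 2


# Hypothetical payload: dicts are walked recursively, sets/tuples become
# lists, enums become their repr(), and anything json.dumps cannot handle
# falls back to repr() as well.
payload = {
    "phase": Phase.DECODE,
    "request_ids": {"req-1", "req-2"},
    "steps": (1, 2, 3),
}
print(json.dumps(prepare_object_to_dump(payload)))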
20 changes: 19 additions & 1 deletion vllm/v1/engine/core.py
@@ -13,6 +13,7 @@
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.error_report import dump_engine_exception
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
@@ -35,6 +36,14 @@
POLLING_TIMEOUT_S = 2.5


class ModelExecutionV1Error(RuntimeError):
scheduler_output: SchedulerOutput

def __init__(self, *args, scheduler_output):
super().__init__(*args)
self.scheduler_output = scheduler_output


class EngineCore:
"""Inner loop of vLLM's Engine."""

@@ -46,6 +55,7 @@ def __init__(
):
assert vllm_config.model_config.runner_type != "pooling"

self.config = vllm_config
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)

@@ -162,7 +172,15 @@ def step(self) -> EngineCoreOutputs:
self.propose_tokens()

scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
try:
output = self.model_executor.execute_model(scheduler_output)
except BaseException as err:
err = ModelExecutionV1Error(
f"Model execution failure,"
f"reason: {repr(err)}",
scheduler_output=scheduler_output)
dump_engine_exception(err, self.config, 1)
raise err
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
return engine_core_outputs
51 changes: 35 additions & 16 deletions vllm/worker/worker_base.py
@@ -208,6 +208,14 @@ def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")


class ModelExecutionError(RuntimeError):
model_input: BroadcastableModelInput

def __init__(self, *args, model_input):
super().__init__(*args)
self.model_input = model_input


@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
@@ -414,15 +422,20 @@ def execute_model(
and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item()

output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
try:
output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
except BaseException as err:
raise ModelExecutionError(
f"Model execution failure,"
f"reason: {repr(err)}",
model_input=model_input) from err

model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
@@ -472,13 +485,19 @@ def _execute_model_spmd(

kwargs = extract_previous_hidden_states(execute_model_req)

return self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
try:
return self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
except BaseException as err:
raise ModelExecutionError(
f"Model execution failure,"
f"reason: {repr(err)}",
model_input=model_input) from err


class WorkerWrapperBase:
Expand Down