diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index c50e4e244dffe..fec5f3d01cff8 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -29,6 +29,7 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -97,7 +98,10 @@ def subtuple(obj: object,
     if to_override is None:
         to_override = {}
     fields = set(to_copy) | set(to_override.keys())
-    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if type(obj) is dict:
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
     if typename not in _TYPE_CACHE:
         _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                        ' '.join(fields))
@@ -2049,7 +2053,9 @@ def execute_model(
                 # not first or last multi-step
                 return []
             # last multi-step
-            output = self._decode_sampler_outputs(model_input)
+            output = self._decode_sampler_outputs(
+                model_input) if self.is_driver_worker else []
+            torch.hpu.synchronize()
         if model_input.is_first_multi_step:
             # first multi-step
             if self.lora_config:
@@ -2110,6 +2116,20 @@ def execute_model(
                 sampling_metadata.skip_sampler_cpu_output = True
                 self.model.model.sampler.include_gpu_probs_tensor = True
             for i in range(num_steps):
+                if i != 0 and not self.is_driver_worker:
+                    broadcast_data = broadcast_tensor_dict(src=0)
+                    if 'early_exit' in broadcast_data and broadcast_data[
+                            'early_exit']:
+                        return [output] if num_steps == 1 else []
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        broadcast_data["input_ids"],
+                        "positions":
+                        broadcast_data["positions"],
+                        "attn_metadata":
+                        self.trim_attn_metadata(
+                            broadcast_data["attn_metadata"])
+                    })
                 with self.profiler.record_event('internal', model_event_name):
                     hidden_states = self.model.forward(
                         **execute_model_kwargs,
@@ -2135,7 +2155,7 @@ def execute_model(
                     htorch.core.mark_step()
                 # Only perform sampling in the driver worker.
                 if not self.is_driver_worker:
-                    return []
+                    continue
 
                 if model_input.async_callback is not None:
                     model_input.async_callback()
@@ -2170,6 +2190,8 @@ def execute_model(
                             dummy_token = (540, )
                             data.output_token_ids += (dummy_token)
                     else:
+                        broadcast_tensor_dict({'early_exit': True},
+                                              src=0)
                         if num_steps == 1:
                             return [output]
                         else:
@@ -2185,6 +2207,12 @@ def execute_model(
                         "attn_metadata":
                         self.trim_attn_metadata(result.attn_metadata)
                     })
+                    model_kwargs_broadcast_data = {
+                        "input_ids": result.input_tokens,
+                        "positions": result.input_positions,
+                        "attn_metadata": vars(result.attn_metadata)
+                    }
+                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
@@ -2199,7 +2227,7 @@ def execute_model(
                 is_prompt=is_prompt)
             self.profiler.record_counter(self.event_start, counters)
             if num_steps == 1:
-                return [output]
+                return [output] if self.is_driver_worker else []
             else:
                 return []
         return output if type(output) is list else [output]
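
Below is a minimal, single-process sketch of the per-step handshake this diff adds for multi-step decoding with TP > 1. `fake_broadcast`, `step_on_rank`, and the lock-step loop are stand-ins invented for illustration; in vLLM the driver publishes via `broadcast_tensor_dict(data, src=0)`, non-driver ranks block on `broadcast_tensor_dict(src=0)`, and each rank runs `execute_model` in its own process.

```python
# Illustrative sketch only: fake_broadcast stands in for
# vllm.distributed.broadcast_tensor_dict, and both "ranks" run in one
# process here, whereas vLLM runs each rank in its own process.
from typing import Any, Dict, Optional

_channel: Dict[str, Any] = {}  # stand-in for the TP broadcast group


def fake_broadcast(data: Optional[Dict[str, Any]] = None,
                   src: int = 0) -> Dict[str, Any]:
    """Driver (src) publishes a dict; every other rank reads it."""
    global _channel
    if data is not None:
        _channel = dict(data)
    return _channel


def step_on_rank(rank: int, step: int, num_steps: int,
                 kwargs: Dict[str, Any]) -> bool:
    """One multi-step iteration; returns False once this rank is done."""
    is_driver = rank == 0
    if step != 0 and not is_driver:
        # Mirrors the new worker-side branch: receive either the next
        # step's inputs or the early-exit signal from the driver.
        data = fake_broadcast(src=0)
        if data.get("early_exit"):
            return False
        kwargs.update(data)  # input_ids / positions / attn_metadata
    # ... the model forward would run here on every rank ...
    if is_driver:
        if step == num_steps - 1:
            # Last step: unblock the workers instead of leaving them
            # waiting on a broadcast that would never arrive.
            fake_broadcast({"early_exit": True}, src=0)
        else:
            # Mirrors model_kwargs_broadcast_data in the diff.
            fake_broadcast({"input_ids": f"tokens@step{step}"}, src=0)
    return True


# Lock-step simulation: the driver publishes, then the worker consumes.
for step in range(3):
    assert step_on_rank(0, step, num_steps=3, kwargs={})
    step_on_rank(1, step, num_steps=3, kwargs={})
```

This also explains why the non-driver sampling path now ends in `continue` rather than `return []`: workers must keep stepping through all `num_steps` forwards, while only the driver samples and returns real outputs (hence `return [output] if self.is_driver_worker else []`).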
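
The dict branch added to `subtuple` supports this exchange: the driver broadcasts `vars(result.attn_metadata)`, so on non-driver ranks `trim_attn_metadata` receives a plain dict rather than the metadata object. A hedged reconstruction of that round trip, with `AttnStub` as a hypothetical stand-in for the real HPU attention metadata class:

```python
import collections
from dataclasses import dataclass

_TYPE_CACHE: dict = {}


def subtuple(obj, typename, to_copy, to_override=None):
    # Same logic as the patched helper in the diff above.
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    if type(obj) is dict:
        values = {key: obj[key] for key in fields if key in obj}
    else:
        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(
            typename, ' '.join(fields))
    return _TYPE_CACHE[typename](**values)


@dataclass
class AttnStub:  # hypothetical stand-in for the HPU attention metadata
    block_list: list
    seq_lens_tensor: list


meta = AttnStub(block_list=[0, 1], seq_lens_tensor=[8, 8])
fields = ['block_list', 'seq_lens_tensor']
# Driver side: trimming works on the metadata object via getattr.
t_obj = subtuple(meta, 'TrimmedAttn', fields)
# Worker side: the payload arrives as vars(meta), i.e. a plain dict,
# and the new dict branch yields the identical trimmed namedtuple.
t_dict = subtuple(vars(meta), 'TrimmedAttn', fields)
assert t_obj == t_dict
```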