Tensor parallelism for multi-step scheduling (#457)
This PR implements tensor parallelism for multi-step scheduling.
tzielinski-habana authored Nov 5, 2024
1 parent ac12d53 commit 653e56c
Showing 1 changed file with 32 additions and 4 deletions.
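In outline, the change keeps sampling on the driver worker and uses broadcast_tensor_dict to hand each intermediate step's inputs (or an early-exit signal once all sequences are finished) to the other tensor-parallel ranks, which pick them up at the top of their next step. The sketch below is a schematic of that handshake only; it reuses names from the diff (broadcast_tensor_dict, trim_attn_metadata, is_driver_worker) but is illustrative rather than a verbatim excerpt of HPUModelRunner, and `runner` is simply a stand-in for the model runner instance.

from vllm.distributed import broadcast_tensor_dict

def publish_step_inputs(result):
    # Driver side: ship the next step's inputs to the other TP ranks.
    broadcast_tensor_dict(
        {
            "input_ids": result.input_tokens,
            "positions": result.input_positions,
            # attn_metadata is sent as a plain dict so it survives broadcasting.
            "attn_metadata": vars(result.attn_metadata),
        },
        src=0)

def publish_early_exit():
    # Driver side: tell the other ranks there is nothing left to compute.
    broadcast_tensor_dict({'early_exit': True}, src=0)

def receive_step_inputs(runner, execute_model_kwargs):
    # Non-driver side: block until the driver publishes something.
    data = broadcast_tensor_dict(src=0)
    if data.get('early_exit'):
        return False  # stop the multi-step loop on this rank
    execute_model_kwargs.update({
        "input_ids": data["input_ids"],
        "positions": data["positions"],
        "attn_metadata": runner.trim_attn_metadata(data["attn_metadata"]),
    })
    return True

In the diff that follows, execute_model takes the receive path at the start of every step after the first on non-driver ranks, while the driver publishes either the next step's inputs or the early-exit flag at the end of each intermediate step.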
36 changes: 32 additions & 4 deletions vllm/worker/hpu_model_runner.py
@@ -29,6 +29,7 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -97,7 +98,10 @@ def subtuple(obj: object,
     if to_override is None:
         to_override = {}
     fields = set(to_copy) | set(to_override.keys())
-    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if type(obj) is dict:
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
     if typename not in _TYPE_CACHE:
         _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                        ' '.join(fields))
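The new dict branch in subtuple matters because attn_metadata reaches the non-driver ranks as a plain dict: later in this diff the driver broadcasts vars(result.attn_metadata), and trim_attn_metadata has to trim that dict just as it would trim the original metadata object. A minimal, self-contained sketch of the two access paths (a simplified stand-in, not the file's exact helper):

import collections
from types import SimpleNamespace

_TYPE_CACHE: dict = {}

def subtuple_sketch(obj, typename, to_copy, to_override=None):
    # Simplified reproduction of the patched logic, for illustration only.
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    if type(obj) is dict:
        # Broadcast metadata arrives as a plain dict, so read it by key.
        values = {key: obj[key] for key in fields if key in obj}
    else:
        # Original path: read attributes, honouring explicit overrides.
        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, ' '.join(fields))
    return _TYPE_CACHE[typename](**values)

# Both forms of the same metadata trim to the same tuple:
meta = SimpleNamespace(block_list=[0, 1], block_size=128)
print(subtuple_sketch(meta, 'TrimmedMeta', ['block_list']))        # from an object
print(subtuple_sketch(vars(meta), 'TrimmedMeta', ['block_list']))  # from a dict

Running the sketch prints the same TrimmedMeta(block_list=[0, 1]) tuple for both call sites.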
@@ -2049,7 +2053,9 @@ def execute_model(
                 # not first or last multi-step
                 return []
             # last multi-step
-            output = self._decode_sampler_outputs(model_input)
+            output = self._decode_sampler_outputs(
+                model_input) if self.is_driver_worker else []
+            torch.hpu.synchronize()
         if model_input.is_first_multi_step:
             # first multi-step
             if self.lora_config:
@@ -2110,6 +2116,20 @@ def execute_model(
                 sampling_metadata.skip_sampler_cpu_output = True
                 self.model.model.sampler.include_gpu_probs_tensor = True
             for i in range(num_steps):
+                if i != 0 and not self.is_driver_worker:
+                    broadcast_data = broadcast_tensor_dict(src=0)
+                    if 'early_exit' in broadcast_data and broadcast_data[
+                            'early_exit']:
+                        return [output] if num_steps == 1 else []
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        broadcast_data["input_ids"],
+                        "positions":
+                        broadcast_data["positions"],
+                        "attn_metadata":
+                        self.trim_attn_metadata(
+                            broadcast_data["attn_metadata"])
+                    })
                 with self.profiler.record_event('internal', model_event_name):
                     hidden_states = self.model.forward(
                         **execute_model_kwargs,
@@ -2135,7 +2155,7 @@ def execute_model(
                 htorch.core.mark_step()
                 # Only perform sampling in the driver worker.
                 if not self.is_driver_worker:
-                    return []
+                    continue
 
                 if model_input.async_callback is not None:
                     model_input.async_callback()
@@ -2170,6 +2190,8 @@ def execute_model(
                                 dummy_token = (540, )
                                 data.output_token_ids += (dummy_token)
                             else:
+                                broadcast_tensor_dict({'early_exit': True},
+                                                      src=0)
                                 if num_steps == 1:
                                     return [output]
                                 else:
@@ -2185,6 +2207,12 @@
                         "attn_metadata":
                         self.trim_attn_metadata(result.attn_metadata)
                     })
+                    model_kwargs_broadcast_data = {
+                        "input_ids": result.input_tokens,
+                        "positions": result.input_positions,
+                        "attn_metadata": vars(result.attn_metadata)
+                    }
+                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
@@ -2199,7 +2227,7 @@ def execute_model(
                     is_prompt=is_prompt)
                 self.profiler.record_counter(self.event_start, counters)
             if num_steps == 1:
-                return [output]
+                return [output] if self.is_driver_worker else []
             else:
                 return []
         return output if type(output) is list else [output]
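For completeness, the new code path is only taken when multi-step scheduling runs with more than one tensor-parallel rank. Assuming the pre-existing vLLM engine arguments num_scheduler_steps and tensor_parallel_size (neither is added by this PR) and an arbitrary placeholder model, a configuration along these lines would exercise it on HPU:

from vllm import LLM, SamplingParams

# tensor_parallel_size > 1 makes every rank except the driver take the new
# broadcast-receive branch; num_scheduler_steps > 1 enables multi-step scheduling.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          tensor_parallel_size=2,
          num_scheduler_steps=8)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

With tensor_parallel_size=1 the driver is the only rank, so the broadcast branches above are never reached.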