[HPU] Add mark_step configurable for the decoder layer. (#525)
We are seeing a 10% performance regression in llama-based models due to
vllm-project#10239. The mark_step() call needs to be configured differently
for each model to achieve the best performance: for some models it is optimal
to call mark_step() after every decoder layer, while for others it is better
to call it only after every n-th layer. This change adds a counter so that
the hook is registered only on every n-th decoder layer, configurable via the
VLLM_CONFIG_HIDDEN_LAYERS environment variable.
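
As a usage sketch (the chosen value and the in-process export are illustrative
assumptions, not part of this change), the interval can be set through the
environment before the HPU model runner loads the model:

    import os

    # Illustrative: emit a mark_step after every 4th decoder layer instead of
    # after every layer. hpu_model_runner reads this value with
    # os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'), so the default remains 1.
    os.environ["VLLM_CONFIG_HIDDEN_LAYERS"] = "4"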
jiminha authored Nov 26, 2024
1 parent 38c2d10 commit b62f1b2
Showing 4 changed files with 36 additions and 24 deletions.
6 changes: 0 additions & 6 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -220,10 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
@@ -252,8 +248,6 @@ def forward(
             hidden_states = layer(hidden_states,
                                   kv_caches[i - self.start_layer],
                                   attn_metadata)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
8 changes: 1 addition & 7 deletions vllm/model_executor/models/llama.py
@@ -315,11 +315,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -346,13 +341,12 @@ def forward(
         if is_hpu:
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i - self.start_layer],
                                             attn_metadata, residual)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
3 changes: 0 additions & 3 deletions vllm/model_executor/models/qwen2.py
@@ -328,9 +328,6 @@ def forward(
                 attn_metadata,
                 residual,
             )
-            if current_platform.is_hpu():
-                htorch.core.mark_step()
-
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
43 changes: 35 additions & 8 deletions vllm/worker/hpu_model_runner.py
@@ -131,17 +131,38 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
-    if module.__class__.__name__.endswith(suffix):
+def get_decoder_layer_suffix(model_type):
+    # This sets the suffix for the hidden layer name, which is controlled by
+    # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
+    # applicable for most language models such as LLaMA, Qwen, and BART. If the
+    # model's decoder layer name differs from the default, it will need to
+    # be specified here.
+    decoder_layer_table = {
+        "gpt_bigcode": "BigCodeBlock",
+    }
 
-        def forward_hook(module, args, output):
-            htorch.core.mark_step()
-            return output
+    return decoder_layer_table.get(model_type, "DecoderLayer")
+
+
+def modify_decoder_layer(module: torch.nn.Module,
+                         suffix="DecoderLayer",
+                         n=1,
+                         counter=None):
 
-        module.register_forward_hook(forward_hook)
+    def forward_hook(module, args, output):
+        htorch.core.mark_step()
+        return output
+
+    if counter is None:
+        counter = [0]
 
     for child_name, child_module in module.named_children():
-        modify_decoder_layer(child_module)
+        if child_module.__class__.__name__.endswith(suffix):
+            counter[0] += 1
+            if counter[0] % n == 0:
+                child_module.register_forward_hook(forward_hook)
+        else:
+            modify_decoder_layer(child_module, suffix, n, counter)
 
 
 class HpuModelAdapter:
@@ -613,7 +634,13 @@ def load_model(self) -> None:
         elif not is_fake_hpu():
             self.model = self.model.to("hpu")
             htcore.mark_step()
-            modify_decoder_layer(self.model)
+
+            hidden_layer_markstep_interval = int(
+                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
+            modify_decoder_layer(
+                self.model,
+                get_decoder_layer_suffix(self.model.config.model_type),
+                hidden_layer_markstep_interval)
         torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m_wrap:
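
For reference, a minimal standalone sketch of the hook-placement pattern
introduced in modify_decoder_layer() above. The add_hook_every_nth name, the
TinyBlock/TinyModel classes, and the print callback are illustrative stand-ins
(assumptions, not part of this change); on HPU the callback body would be
htorch.core.mark_step():

    import torch
    import torch.nn as nn


    def add_hook_every_nth(module: nn.Module, suffix: str = "Block", n: int = 2,
                           counter=None):
        # Register a forward hook on every n-th submodule whose class name
        # ends with `suffix`, mirroring modify_decoder_layer() above.
        if counter is None:
            counter = [0]

        def forward_hook(mod, args, output):
            # On HPU this is where htorch.core.mark_step() would be called.
            print(f"hook fired after {mod.__class__.__name__}")
            return output

        for _, child in module.named_children():
            if child.__class__.__name__.endswith(suffix):
                counter[0] += 1
                if counter[0] % n == 0:
                    child.register_forward_hook(forward_hook)
            else:
                add_hook_every_nth(child, suffix, n, counter)


    class TinyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)

        def forward(self, x):
            return self.linear(x)


    class TinyModel(nn.Module):
        def __init__(self, num_layers=4):
            super().__init__()
            self.layers = nn.ModuleList([TinyBlock() for _ in range(num_layers)])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x


    model = TinyModel()
    add_hook_every_nth(model, suffix="Block", n=2)  # hooks the 2nd and 4th block
    out = model(torch.randn(1, 8))                  # the hook fires twice

With n=1 (the default used by the change above) every matching layer gets the
hook, which matches the previous per-layer mark_step behavior.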
