add mark_step for decoder layers
yisonzhu authored and michalkuligowski committed Jan 7, 2025
1 parent 2d24be7 · commit 5efc637
Showing 1 changed file with 17 additions and 11 deletions.
vllm/worker/hpu_model_runner.py

@@ -135,7 +135,7 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def get_decoder_layer_suffix(model_type):
+def get_target_layer_suffix(model_type) -> list[str]:
     # This sets the suffix for the hidden layer name, which is controlled by
     # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
     # applicable for most language models such as LLaMA, Qwen, and BART. If the
@@ -145,13 +145,17 @@ def get_decoder_layer_suffix(model_type):
         "gpt_bigcode": "BigCodeBlock",
     }
 
-    return decoder_layer_table.get(model_type, "DecoderLayer")
+    return [
+        decoder_layer_table.get(model_type, "DecoderLayer"), "EncoderLayer"
+    ]
 
 
-def modify_decoder_layer(module: torch.nn.Module,
-                         suffix="DecoderLayer",
-                         n=1,
-                         counter=None):
+def modify_model_layers(module: torch.nn.Module,
+                        suffix: list[str],
+                        n=1,
+                        counter=None):
+    """Currently add mark_step at the end of specified layers.
+    """
 
     def forward_hook(module, args, output):
         htorch.core.mark_step()
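
As a hedged illustration (not part of the commit): the helper now returns a list of suffixes rather than a single string, so encoder blocks are matched alongside decoder blocks, and model types absent from the table fall back to the default "DecoderLayer":

    # Hypothetical usage of get_target_layer_suffix; outputs follow from the
    # suffix table above.
    print(get_target_layer_suffix("gpt_bigcode"))  # ['BigCodeBlock', 'EncoderLayer']
    print(get_target_layer_suffix("llama"))        # ['DecoderLayer', 'EncoderLayer']
    print(get_target_layer_suffix(None))           # ['DecoderLayer', 'EncoderLayer']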
@@ -161,12 +165,14 @@ def forward_hook(module, args, output):
         counter = [0]
 
     for child_name, child_module in module.named_children():
-        if child_module.__class__.__name__.endswith(suffix):
+        if any(
+                child_module.__class__.__name__.endswith(layer)
+                for layer in suffix):
             counter[0] += 1
             if counter[0] % n == 0:
                 child_module.register_forward_hook(forward_hook)
         else:
-            modify_decoder_layer(child_module, suffix, n, counter)
+            modify_model_layers(child_module, suffix, n, counter)
 
 
 def get_path_to_rope(model: torch.nn.Module):
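
To see the renamed traversal in isolation, here is a self-contained sketch that runs without Habana hardware; the toy class names and the print stand-in for htorch.core.mark_step are illustrative, not from the commit. It shows how modify_model_layers recurses through named_children, counts children whose class name ends with one of the suffixes, and registers the hook on every n-th match:

    import torch

    class ToyDecoderLayer(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                ToyDecoderLayer() for _ in range(4))

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    def modify_model_layers(module, suffix, n=1, counter=None):
        def forward_hook(module, args, output):
            print("mark_step")  # htorch.core.mark_step() in the real code

        if counter is None:
            counter = [0]
        for child_name, child_module in module.named_children():
            if any(
                    child_module.__class__.__name__.endswith(layer)
                    for layer in suffix):
                counter[0] += 1
                if counter[0] % n == 0:
                    child_module.register_forward_hook(forward_hook)
            else:
                modify_model_layers(child_module, suffix, n, counter)

    model = ToyModel()
    modify_model_layers(model, ["DecoderLayer", "EncoderLayer"], n=2)
    model(torch.zeros(1))  # prints "mark_step" twice, after layers 2 and 4

The ModuleList itself matches neither suffix, so the function recurses into it and finds the four ToyDecoderLayer children; with n=2, hooks land on the second and fourth layers.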
@@ -753,10 +759,10 @@ def load_model(self) -> None:
             hidden_layer_markstep_interval = int(
                 os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
             model_config = getattr(self.model, "config", None)
-            modify_decoder_layer(
+            modify_model_layers(
                 self.model,
-                get_decoder_layer_suffix(model_config.model_type if
-                                         model_config is not None else None),
+                get_target_layer_suffix(model_config.model_type
+                                        if model_config is not None else None),
                 hidden_layer_markstep_interval)
             path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
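The interval n comes from VLLM_CONFIG_HIDDEN_LAYERS (default 1, i.e. a mark_step after every matching layer), so deployments can trade graph size against mark_step overhead without code changes; the value below is illustrative:

    export VLLM_CONFIG_HIDDEN_LAYERS=4  # mark_step after every 4th matching layer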
