
Add mark_step for encoder layers #650

Closed · wants to merge 1 commit
28 changes: 17 additions & 11 deletions vllm/worker/hpu_model_runner.py
```diff
@@ -135,7 +135,7 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def get_decoder_layer_suffix(model_type):
+def get_target_layer_suffix(model_type) -> list[str]:
```

Review comment: Please change to something like "get_target_layer_suffix_list" to reflect the return type.

```diff
     # This sets the suffix for the hidden layer name, which is controlled by
     # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
     # applicable for most language models such as LLaMA, Qwen, and BART. If the
@@ -145,13 +145,17 @@ def get_decoder_layer_suffix(model_type):
         "gpt_bigcode": "BigCodeBlock",
     }
 
-    return decoder_layer_table.get(model_type, "DecoderLayer")
+    return [
+        decoder_layer_table.get(model_type, "DecoderLayer"), "EncoderLayer"
+    ]
```
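For illustration (not part of the diff): because the helper now returns a list, callers match both the model-specific decoder suffix and the generic "EncoderLayer" suffix. A sketch of the expected return values, assuming "llama" is not an entry in decoder_layer_table (as the comment above suggests):

```python
# Expected return values after this change (sketch, not from the PR):
get_target_layer_suffix("gpt_bigcode")  # ["BigCodeBlock", "EncoderLayer"]
get_target_layer_suffix("llama")        # ["DecoderLayer", "EncoderLayer"]
get_target_layer_suffix(None)           # ["DecoderLayer", "EncoderLayer"]
```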


```diff
-def modify_decoder_layer(module: torch.nn.Module,
-                         suffix="DecoderLayer",
-                         n=1,
-                         counter=None):
+def modify_model_layers(module: torch.nn.Module,
+                        suffix: list[str],
```

Review comment: Please change to something like "suffix_list" to reflect the underlying type.

```diff
+                        n=1,
+                        counter=None):
     """Currently add mark_step at the end of specified layers.
     """
 
     def forward_hook(module, args, output):
         htorch.core.mark_step()
@@ -161,12 +165,14 @@ def forward_hook(module, args, output):
         counter = [0]
 
     for child_name, child_module in module.named_children():
-        if child_module.__class__.__name__.endswith(suffix):
+        if any(
+                child_module.__class__.__name__.endswith(layer)
+                for layer in suffix):
             counter[0] += 1
             if counter[0] % n == 0:
                 child_module.register_forward_hook(forward_hook)
         else:
-            modify_decoder_layer(child_module, suffix, n, counter)
+            modify_model_layers(child_module, suffix, n, counter)
 
 
 def get_path_to_rope(model: torch.nn.Module):
```
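To make the traversal concrete, here is a minimal, self-contained sketch (not part of the PR) of how modify_model_layers registers hooks: it recurses through named_children() and hooks every n-th module whose class name ends with one of the suffixes. htorch.core.mark_step() is stubbed with a print so the sketch runs without Habana's htorch:

```python
import torch
import torch.nn as nn


def mark_step():
    # Stand-in for htorch.core.mark_step(), so the sketch runs anywhere.
    print("mark_step called")


def modify_model_layers(module, suffix, n=1, counter=None):
    """Register a mark_step forward hook on every n-th matching layer."""

    def forward_hook(module, args, output):
        mark_step()

    if counter is None:
        counter = [0]

    for child_name, child_module in module.named_children():
        if any(
                child_module.__class__.__name__.endswith(layer)
                for layer in suffix):
            counter[0] += 1
            if counter[0] % n == 0:
                child_module.register_forward_hook(forward_hook)
        else:
            modify_model_layers(child_module, suffix, n, counter)


class ToyEncoderLayer(nn.Module):
    def forward(self, x):
        return x


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A ModuleList does not match the suffixes, so the walk recurses
        # into it and finds the four ToyEncoderLayer children.
        self.layers = nn.ModuleList(ToyEncoderLayer() for _ in range(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
modify_model_layers(model, ["DecoderLayer", "EncoderLayer"], n=2)
model(torch.zeros(1))  # hooks fire on layers 2 and 4: prints twice
```

With n=2 only every second matching layer gets the hook, which is exactly the interval that VLLM_CONFIG_HIDDEN_LAYERS controls in load_model() below.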
```diff
@@ -753,10 +759,10 @@ def load_model(self) -> None:
         hidden_layer_markstep_interval = int(
             os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
         model_config = getattr(self.model, "config", None)
-        modify_decoder_layer(
+        modify_model_layers(
             self.model,
-            get_decoder_layer_suffix(model_config.model_type if
-                                     model_config is not None else None),
+            get_target_layer_suffix(model_config.model_type
+                                    if model_config is not None else None),
             hidden_layer_markstep_interval)
         path_to_rope = get_path_to_rope(self.model)
         torch.hpu.synchronize()
```
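Usage note (not part of the diff): the interval is read from the VLLM_CONFIG_HIDDEN_LAYERS environment variable and defaults to 1, i.e. a mark_step after every matching layer. A hypothetical configuration:

```python
# Hypothetical example: request a mark_step after every 4th matching layer.
# Must be set in the environment before load_model() runs.
import os

os.environ["VLLM_CONFIG_HIDDEN_LAYERS"] = "4"
```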