Merge branch 'habana_main' into private/yishan/bench_throughput_mllama
yisonzhu committed Dec 19, 2024
2 parents d78e77d + 88ef381 commit c932cc5
Showing 7 changed files with 106 additions and 67 deletions.
@@ -5,10 +5,10 @@ tasks:
- name: "gsm8k_cot_llama"
metrics:
- name: "exact_match,strict-match"
value: 0.8317
value: 0.664
- name: "exact_match,flexible-extract"
value: 0.8355
limit: null
value: 0.676
limit: 250
num_fewshot: 8
dtype: "bfloat16"
fewshot_as_multiturn: true
16 changes: 16 additions & 0 deletions .jenkins/lm-eval-harness/inc_unit_scales_config.json
@@ -0,0 +1,16 @@
{
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "unit_scale",
"allowlist": {
"types": [],
"names": []
},
"blocklist": {
"types": [],
"names": [
"lm_head"
]
},
"dump_stats_path": ""
}
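The new file is an Intel Neural Compressor (INC) FP8 quantization config: scale_method "unit_scale" uses fixed unit scales rather than scales measured from per-model calibration data, and the blocklist keeps lm_head unquantized. A minimal sketch of how the harness consumes it, mirroring the updated setup_fp8() further down (the surrounding launch code is assumed, not part of this commit):

import os

# INC reads the config path from the QUANT_CONFIG environment variable; it must be
# set before the quantized model is initialized (setup_fp8() below does exactly this
# for model configs that set fp8: true).
os.environ["QUANT_CONFIG"] = "inc_unit_scales_config.json"

# ... then launch the lm-eval run as usual; with unit scales no per-model
# calibration file is required.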
8 changes: 5 additions & 3 deletions .jenkins/lm-eval-harness/run-tests.sh
@@ -43,14 +43,16 @@ do
export PT_HPU_ENABLE_LAZY_COLLECTIVES=true
export VLLM_SKIP_WARMUP=true
RANDOM_SUFFIX=$(tr -dc A-Za-z0-9 </dev/urandom | head -c 4; echo)
JUNIT_SUFFIX=""
JUNIT_FAMILY=""
JUNIT_XML=""
if [[ -n "$TEST_RESULTS_DIR" ]]; then
LOG_DIR=$TEST_RESULTS_DIR
LOG_FILENAME="test_${MODEL_CONFIG}_${RANDOM_SUFFIX}.xml"
LOG_PATH="${LOG_DIR}/${LOG_FILENAME}"
JUNIT_SUFFIX="-o junit_family=xunit1 --junitxml=${LOG_PATH}"
JUNIT_FAMILY="-o junit_family=xunit1"
JUNIT_XML="--junitxml=${LOG_PATH}"
fi
pytest -s test_lm_eval_correctness.py "$JUNIT_SUFFIX" || LOCAL_SUCCESS=$?
pytest -s test_lm_eval_correctness.py "$JUNIT_FAMILY" "$JUNIT_XML" || LOCAL_SUCCESS=$?

if [[ $LOCAL_SUCCESS == 0 ]]; then
echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
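Splitting JUNIT_SUFFIX into JUNIT_FAMILY and JUNIT_XML changes how the options reach pytest: one quoted variable expands to a single argument containing both flags, while two quoted variables expand to two separate arguments. A small, purely illustrative Python check of that word-splitting (shlex only mimics the shell's tokenization here):

import shlex

# Both options in one quoted variable arrive as a single argv token:
print(shlex.split(
    "pytest -s test_lm_eval_correctness.py '-o junit_family=xunit1 --junitxml=out.xml'"))
# ['pytest', '-s', 'test_lm_eval_correctness.py',
#  '-o junit_family=xunit1 --junitxml=out.xml']

# Two quoted variables (as in the patched script) arrive as two tokens:
print(shlex.split(
    "pytest -s test_lm_eval_correctness.py '-o junit_family=xunit1' '--junitxml=out.xml'"))
# ['pytest', '-s', 'test_lm_eval_correctness.py',
#  '-o junit_family=xunit1', '--junitxml=out.xml']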
8 changes: 3 additions & 5 deletions .jenkins/lm-eval-harness/test_lm_eval_correctness.py
@@ -27,12 +27,10 @@
TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)


def setup_fp8(model_path, device_type):
flavor = f"g{device_type[-1]}"
normalized_model_name = Path(model_path).parts[-1].lower()
def setup_fp8():
os.environ[
"QUANT_CONFIG"] = \
f"/software/data/vllm-benchmarks/inc/{normalized_model_name}/maxabs_quant_{flavor}.json"
"inc_unit_scales_config.json"


def fail_on_exit():
@@ -147,7 +145,7 @@ def test_lm_eval_correctness(record_xml_attribute, record_property):

# Set up environment for FP8 inference
if eval_config.get("fp8"):
setup_fp8(eval_config["model_name"], platform)
setup_fp8()
# Launch eval requests.
start_time = time.perf_counter()
results = launch_lm_eval(eval_config)
3 changes: 3 additions & 0 deletions vllm/executor/multiproc_hpu_executor.py
@@ -42,6 +42,9 @@ def _check_executor_parameters(self):
f"please ensure that world_size ({world_size}) "
f"is less than than max local hpu count ({hpu_device_count})")

def shutdown_inc(self):
self._run_workers("shutdown_inc")

def __del__(self):
self.shutdown()

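The added shutdown_inc() simply fans the call out to every worker through the existing _run_workers helper, so each rank can finalize its INC (Intel Neural Compressor) quantization state. A toy sketch of that delegation pattern; ToyWorker and ToyExecutor are stand-ins, not vLLM classes:

class ToyWorker:
    def shutdown_inc(self):
        print("worker: finalizing INC quantization state")

class ToyExecutor:
    def __init__(self, workers):
        self.workers = workers

    def _run_workers(self, method, *args, **kwargs):
        # Broadcast a method call to every worker and collect the results.
        return [getattr(worker, method)(*args, **kwargs) for worker in self.workers]

    def shutdown_inc(self):
        self._run_workers("shutdown_inc")

ToyExecutor([ToyWorker(), ToyWorker()]).shutdown_inc()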
6 changes: 0 additions & 6 deletions vllm/platforms/hpu.py
@@ -1,7 +1,5 @@
from typing import TYPE_CHECKING, Optional

import torch

from .interface import Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
@@ -24,10 +22,6 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True

@staticmethod
def inference_mode():
return torch.no_grad()

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

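Removing the inference_mode() override (together with the now-unused torch import) means the HPU platform no longer forces torch.no_grad() and instead inherits the base Platform behaviour, which in upstream vLLM is built on torch.inference_mode(). The practical difference between the two context managers, shown with plain PyTorch (independent of vLLM):

import torch

x = torch.ones(2, 2)

# torch.no_grad(): autograd is skipped, but the outputs are ordinary tensors.
with torch.no_grad():
    y = x * 2
print(y.requires_grad, y.is_inference())   # False False

# torch.inference_mode(): stricter; also skips view/version-counter tracking,
# which is typically cheaper for pure inference workloads.
with torch.inference_mode():
    z = x * 2
print(z.is_inference())                    # True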
126 changes: 76 additions & 50 deletions vllm/worker/hpu_model_runner.py
@@ -169,40 +169,37 @@ def forward_hook(module, args, output):
modify_decoder_layer(child_module, suffix, n, counter)


def get_names_for_rope(model: torch.nn.Module):
"""Dynamically get layer names needed for cos and sin preparation for rope.
Every model can have a different naming convention for its layers.
This function dynamically retrieves layer names to access rope layer.
If there's no rope layer, the function returns None.
This function assumes the following layer type layout:
Model -> ModuleList -> Attention -> RotaryEmbedding
def get_path_to_rope(model: torch.nn.Module):
"""Dynamically get the path to the RotaryEmbedding layer in the model.
This function will recursively search through the module hierarchy to find
a RotaryEmbedding layer and return the full path to that layer as a list
of names.
If no such layer is found, it returns None.
"""

def get_child(parent, suffix, is_list=False):
def find_rope_layer(parent, path):
# Base case: check if this parent is None
if parent is None:
return None, None
parent = parent[0] if is_list else parent
for child_name, child_module in parent.named_children():
if child_module.__class__.__name__.endswith(suffix):
return child_name, child_module
return None, None

model_name, model_module = get_child(model, "Model")
layers_name, layers_module = get_child(model_module, "ModuleList")
attn_name, attn_module = get_child(layers_module,
"Attention",
is_list=True)
rope_name, _ = get_child(attn_module, "RotaryEmbedding")

if rope_name is not None:
return {
'model_name': model_name,
'layers_name': layers_name,
'attn_name': attn_name,
'rope_name': rope_name
}
return None

# Check if the current layer is a RotaryEmbedding
if hasattr(parent, 'named_children'):
for child_name, child_module in parent.named_children():
# If the current child is of type RotaryEmbedding,
# return the full path
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
return path + [child_name]
# Otherwise, recurse into this child to check its children
result = find_rope_layer(child_module, path + [child_name])
if result is not None:
return result
return None

# Start the search from the top level model
path_to_rope = find_rope_layer(model, [])

# Return the result if found, otherwise None
return path_to_rope
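
A standalone sketch of what the new search returns. The Toy* modules below are illustrative stand-ins (not vLLM classes); the helper mirrors the find_rope_layer logic above:

import torch.nn as nn

class RotaryEmbedding(nn.Module):
    pass

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.rotary_emb = RotaryEmbedding()

class ToyDecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToyAttention()

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyDecoderLayer(), ToyDecoderLayer()])

def find_rope_layer(parent, path):
    # Depth-first search by class-name suffix, as in the function above.
    if parent is None:
        return None
    for child_name, child_module in parent.named_children():
        if child_module.__class__.__name__.endswith("RotaryEmbedding"):
            return path + [child_name]
        result = find_rope_layer(child_module, path + [child_name])
        if result is not None:
            return result
    return None

print(find_rope_layer(ToyModel(), []))
# ['layers', '0', 'self_attn', 'rotary_emb']

Note that nn.ModuleList registers its children under string names like '0', which is why _prepare_cos_sin below treats digit path components as indices.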


class HpuModelAdapter:
@@ -296,11 +293,11 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))

if not is_fake_hpu() and htorch.utils.internal.is_lazy():
if not is_fake_hpu():
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
else:
# Unfortunately one_hot on CPU/torch.compile mode/eager mode
# Unfortunately one_hot on CPU
# doesn't handle out of bounds classes so we need to convert
# all negative values to 0 (block_mapping) or bs (block_groups)
block_groups = metadata.block_groups.to(torch.long)
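For context on the branch above: the CPU fallback exists because one_hot cannot encode the negative indices that mark padded block slots, so they must be remapped first. A small demonstration; the remapping shown is one illustrative option, not the exact code that follows in this file:

import torch
import torch.nn.functional as F

block_groups = torch.tensor([0, 1, -1])        # -1 marks an unused / padded slot

try:
    F.one_hot(block_groups, num_classes=2)
except RuntimeError as err:
    print("one_hot rejects negative classes:", err)

# Remap negatives into an extra "padding" bucket, then drop that bucket again.
safe = torch.where(block_groups < 0, torch.full_like(block_groups, 2), block_groups)
mapping = F.one_hot(safe, num_classes=3)[:, :2]
print(mapping)   # real blocks stay one-hot, the padded row becomes all zeros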
@@ -353,17 +350,31 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
return attn_metadata

def _prepare_cos_sin(self, positions):
model_name = self.layer_names['model_name']
layers_name = self.layer_names['layers_name']
attn_name = self.layer_names['attn_name']
rope_name = self.layer_names['rope_name']

base_model = getattr(self.model, model_name)
first_model_layer = getattr(base_model, layers_name)[0]
attention_layer = getattr(first_model_layer, attn_name)
rope = getattr(attention_layer, rope_name)

rope.prepare_cos_sin(positions)
"""Navigate through the model using the provided path and call
the prepare_cos_sin method on the 'RotaryEmbedding' layer."""

current_module = self.model # Start from the top level of the model

for layer in self.layer_names:
if layer.isdigit(): # Check if the layer is an index
layer = int(layer)

# Check if the current layer is a name in a module
if isinstance(
layer,
str) and not isinstance(layer, int): # Name-based access
current_module = getattr(current_module, layer)
elif isinstance(layer,
int): # Indexed-based access (like ModuleList)
current_module = list(current_module._modules.values())[layer]

# At the end, we should be at the RotaryEmbedding layer.
if hasattr(current_module, 'prepare_cos_sin'):
current_module.prepare_cos_sin(positions)
else:
raise AttributeError(
"The module at the end of the path does not have \
a 'prepare_cos_sin' method.")

def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
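The loop in _prepare_cos_sin above resolves the stored layer_names path one hop at a time: digit components index into a ModuleList-style container, everything else goes through getattr. A compact, self-contained sketch of the same traversal; nn.Sequential is used only because, like ModuleList, it registers children under the names '0', '1', ..., and the prepare_cos_sin body is a placeholder:

import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def prepare_cos_sin(self, positions):
        print("prepare_cos_sin called with", positions)

model = nn.Sequential(nn.Linear(4, 4), RotaryEmbedding())
path = ['1']                      # hypothetical path, as produced by get_path_to_rope

current_module = model
for layer in path:
    if layer.isdigit():           # index-based hop (ModuleList / Sequential)
        current_module = list(current_module._modules.values())[int(layer)]
    else:                         # name-based hop
        current_module = getattr(current_module, layer)

current_module.prepare_cos_sin([0, 1, 2])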
@@ -744,7 +755,7 @@ def load_model(self) -> None:
get_decoder_layer_suffix(model_config.model_type if
model_config is not None else None),
hidden_layer_markstep_interval)
names_for_rope = get_names_for_rope(self.model)
path_to_rope = get_path_to_rope(self.model)
torch.hpu.synchronize()

with HabanaMemoryProfiler() as m_wrap:
Expand All @@ -753,7 +764,7 @@ def load_model(self) -> None:
self.block_size,
dtype=self.model_config.dtype,
enforce_eager=self.enforce_eager,
layer_names=names_for_rope)
layer_names=path_to_rope)
msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
logger.info(msg)

@@ -2019,6 +2030,19 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

return lora_mask, lora_logits_mask

def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
real_batch_size = len(seq_group_metadata_list)
batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
if batch_size_padding > 0:
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
return seq_group_metadata_list
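
add_dummy_seq() pads the incoming batch up to the next bucketed batch size, which on Gaudi keeps batch shapes within the precompiled bucket set. A purely illustrative sketch of the padding arithmetic; the bucket list and get_padded_batch_size below are stand-ins for the real bucketing_ctx API:

BUCKETS = [1, 2, 4, 8, 16]                       # assumed bucket sizes

def get_padded_batch_size(real_batch_size):
    return next(b for b in BUCKETS if b >= real_batch_size)

seq_group_metadata_list = ["req-0", "req-1", "req-2"]            # 3 real groups
padded = get_padded_batch_size(len(seq_group_metadata_list))     # -> 4
padding = padded - len(seq_group_metadata_list)
seq_group_metadata_list = seq_group_metadata_list + ["<dummy>"] * padding
print(seq_group_metadata_list)   # ['req-0', 'req-1', 'req-2', '<dummy>']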

@torch.inference_mode()
def execute_model(
self,
@@ -2105,8 +2129,8 @@ def execute_model(
def try_revert_dummy_output_tokens():
if len(cache_orig_output_tokens_len) > 0:
# Reuse the original output token ids length
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
for i in range(len(cache_orig_output_tokens_len)):
seq_group_metadata = seq_group_metadata_list[i]
for j, data in seq_group_metadata.seq_data.items():
orig_output_tokens_len = \
cache_orig_output_tokens_len[i][j]
@@ -2184,16 +2208,18 @@ def try_revert_dummy_output_tokens():
else:
raise RuntimeError(
"seq_group_metadata_list is uninitialized")
for i, seq_group_metadata in enumerate(
for seq_idx, seq_group_metadata in enumerate(
seq_group_metadata_list):
# Skip empty steps
seq_group_metadata.state.current_step += (
num_steps - 2)
# Cache the original output token ids
cache_orig_output_tokens_len.append({})
for j, data in seq_group_metadata.seq_data.items():
cache_orig_output_tokens_len[i][j] = \
cache_orig_output_tokens_len[seq_idx][j] = \
len(data.output_token_ids)
seq_group_metadata_list = self.add_dummy_seq(
seq_group_metadata_list, is_prompt=False)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
max_output_len = sampling_metadata.seq_groups[
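Taken together, the last two hunks pad the multi-step decode batch with dummy groups and make try_revert_dummy_output_tokens() iterate over the cached lengths rather than the now-longer padded list. A minimal sketch of why the loop bound matters; the data is invented for illustration:

cache_orig_output_tokens_len = [{0: 5}, {1: 7}]               # cached for 2 real groups
seq_group_metadata_list = ["group-0", "group-1", "<dummy>"]   # padded to bucket size 3

# Iterating over the padded list would index past the cache; bounding the loop by
# the cache length touches only the real groups.
for i in range(len(cache_orig_output_tokens_len)):
    seq_group_metadata = seq_group_metadata_list[i]
    print(seq_group_metadata, "->", cache_orig_output_tokens_len[i])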
