Resolved ALiBi bias regression due to porting flat PA #503

Open
tannervoas742 wants to merge 1 commit into habana_main from restore_alibi_for_flat_pa_final

Conversation

@tannervoas742 commented on Nov 15, 2024

Requires the associated changes in the vllm-hpu-extension PR.

Changes:

  • Added ALiBi biases back to the decode stage.
  • Optimized ALiBi memory usage.
    • Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
      large models to run with restricted prompt lengths.
    • Prompt biases are instantiated once in __init__ rather than on each
      forward pass.
    • Prompt and decode biases are shared across encoder/decoder layers.
  • Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve
    an accuracy issue on long sequences.
  • Updated jais, mpt, falcon, baichuan, and bloom to work with ALiBi.
    • Due to bloom's 176B parameter size I was unable to test that model,
      though its changes are the simplest.
  • Works in both lazy and eager mode.
  • ALiBi requires "VLLM_PROMPT_USE_FUSEDSDPA=false" and
    "VLLM_CONTIGUOUS_PA=true" (see the usage sketch after this list).
  • Added position offsets to improve quality at BS > 1 with sequences of
    varying length.
  • BS > 1 may have accuracy issues on FW < 1.19.0 due to a limitation in
    softmax; resolved on FW >= 1.19.0.
  • NTT patch for GQA.
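
A minimal usage sketch for the configuration above. The environment variable names come from this change list; the values shown, the model name, and the generate call are illustrative assumptions, not part of this PR.

import os

# Flags described in this PR; the exact value format each flag expects is assumed here.
os.environ["VLLM_PROMPT_USE_FUSEDSDPA"] = "false"     # ALiBi path requires fused SDPA disabled
os.environ["VLLM_CONTIGUOUS_PA"] = "true"             # ALiBi path requires contiguous PA
os.environ["VLLM_PROMPT_ALIBI_MAX_SEQ_LEN"] = "4096"  # cap prompt bias length for large models
os.environ["VLLM_ALIBI_USE_FLOAT32_BIASES"] = "true"  # fp32 biases for long-sequence accuracy

from vllm import LLM, SamplingParams  # imported after setting the flags so they take effect

llm = LLM(model="tiiuae/falcon-7b")  # any ALiBi model touched here (falcon, mpt, baichuan, jais, bloom)
out = llm.generate(["ALiBi test prompt"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)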

Co-authored-by: Tanner Voas [email protected]
Co-authored-by: Haihao Xiang [email protected]
Signed-off-by: Tanner Voas [email protected]

@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch 6 times, most recently from 4a0674d to 3959126 on November 18, 2024 10:33
@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch 4 times, most recently from b339767 to 3c3e18a on November 27, 2024 03:05
@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch from 3c3e18a to 6c19183 on November 28, 2024 01:22
@zhouyuan mentioned this pull request on Nov 28, 2024
@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch from 6c19183 to 3cb455d on December 5, 2024 19:23
@tannervoas742 (Author) commented:

@itaraban @madamczykhabana @kzawora-intel has anyone had a chance to review this PR and the associated one on vllm-hpu-extension? I just pushed out a significant update that minimizes the changes to non-ALiBi code sections and includes significant accuracy and memory optimizations.

With the current changes, ALiBi is now fully functional as long as FW >= 1.19.0 is used.

Please help review. Any feedback would be appreciated.

@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch 3 times, most recently from 49fcaaa to 64822b0 on December 10, 2024 16:16
@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch from 64822b0 to 684384e on December 11, 2024 15:04
Review thread on vllm/worker/hpu_model_runner.py (outdated, resolved)
Review thread on vllm/attention/backends/hpu_attn.py (resolved)
@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch 3 times, most recently from 214885e to d3fa482 on December 12, 2024 20:01
@tannervoas742 (Author) commented:

@michalkuligowski I have fixed the static code analysis issues and updated requirements-hpu.txt.

@michalkuligowski left a comment:

@tannervoas742 there are still some issues detected, please check (you can try running the format.sh script):
Error: vllm/attention/layer.py:99: error: Too many arguments for "AttentionImpl" [call-arg]
Error: vllm/attention/backends/hpu_attn.py:279: error: Value of type "Optional[Any]" is not indexable [index]
Error: vllm/attention/backends/hpu_attn.py:291: error: Item "None" of "Optional[Any]" has no attribute "unsqueeze" [union-attr]
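
For reference, [index] and [union-attr] errors like the ones above are usually cleared by narrowing the Optional before using it. The snippet below is only an illustrative pattern with made-up names, not the actual fix applied in hpu_attn.py.

from typing import Optional

import torch


def add_alibi_bias(scores: torch.Tensor,
                   attn_bias: Optional[torch.Tensor]) -> torch.Tensor:
    # scores: (batch, heads, q_len, kv_len); attn_bias: (heads, q_len, kv_len) or None.
    # Asserting the Optional away is what lets indexing and .unsqueeze() type-check.
    assert attn_bias is not None, "ALiBi bias must be built before the forward pass"
    return scores + attn_bias.unsqueeze(0)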

@tannervoas742 (Author) commented:

@michalkuligowski I see the issues now. I wasn't sure where to view the static code analysis report, but I found it and pushed out an update. I'm waiting for the code analysis to run again and will reply here when it's finished and ready for re-review.

@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch 3 times, most recently from 9fac2b5 to ba971fd on December 16, 2024 23:37
@tannervoas742 (Author) commented:

@itaraban @michalkuligowski I have updated the PR and run the tools/mypy.sh script, which passes locally. I also tested the updated version with various ALiBi and non-ALiBi models. Please re-review. I have also reopened the extension PR: HabanaAI/vllm-hpu-extension#60

@tannervoas742 force-pushed the restore_alibi_for_flat_pa_final branch from ba971fd to b937caf on December 17, 2024 17:48
@kwisniewski98 left a comment:

The biggest issue I have right now is that modifying any file that isn't HPU-specific (models, attention backends) will make this hard or impossible to upstream. I didn't want to repeat the comment for each file, but I think those changes should be removed from all of them.

@@ -233,6 +233,8 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
tp_rank: Optional[int] = None,
prev_attn: Optional[torch.nn.Module] = None,

Do we need to modify this class' constructor here? This will be hard to upstream. The TP rank can be obtained using from vllm.distributed import get_tensor_model_parallel_rank, and if I'm understanding this correctly, prev_attn is only used here to reuse the ALiBi bias, which can be generated in each layer separately. Alternatively, it can probably be cached.
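
A rough sketch of this suggestion. Only the get_tensor_model_parallel_rank import is taken from the comment above; the helper name, cache, and slope formula are hypothetical, and the rank lookup assumes it runs inside an initialized vLLM worker.

from functools import lru_cache

import torch
from vllm.distributed import get_tensor_model_parallel_rank


@lru_cache(maxsize=8)
def cached_alibi_bias(num_heads: int, max_seq_len: int,
                      dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Hypothetical helper: build the (heads, seq, seq) ALiBi bias once and share it
    # across layers instead of threading a prev_attn reference through every constructor.
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
    positions = torch.arange(max_seq_len)
    rel = positions[None, :] - positions[:, None]  # (seq, seq), negative for past positions
    return (slopes[:, None, None] * rel[None, :, :]).to(dtype)


# Inside an initialized worker the TP rank can be looked up instead of being passed in:
# tp_rank = get_tensor_model_parallel_rank()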

self.alibi_slopes = None
self.prompt_position_bias = None
# Set upper bound on sequence length
self.max_seq_len = int(

Please change it to max_seq_len_upper_bound or something similar. To me this name looks misleading, considering that the variable is reused.

@@ -127,7 +128,7 @@ def __init__(
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding

I see that there is a typo, but it should be fixed upstream and then rebased on our side, so I'd rather leave it as it is. It would be best not to modify model definitions at all.

@@ -372,6 +390,7 @@ def __init__(
self.lora_config = lora_config

self.quant_config = quant_config
self.use_alibi = position_embedding == "ALIBI"

Is this needed? I don't see it used anywhere.

@@ -38,9 +38,11 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: Optional[int] = 4096,

That default value is already set for our implementation, so I'd rather not change it for others.

if self.use_alibi:
alibi_blocks = self._compute_alibi_block(block_tables, seq_lens,
len(block_groups))
alibi_blocks = alibi_blocks.to( # type: ignore

Can't we just create the tensors on the target device instead of calling .to()? I think this can cause graph breaks in lazy mode, but I might be wrong.
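
To illustrate the two patterns being compared (shapes and the device string are placeholders; "hpu" assumes the Habana PyTorch bridge has been imported):

import torch
# import habana_frameworks.torch  # assumption: needed for the "hpu" device to be registered

num_blocks, block_size = 64, 128

# Pattern in the diff above: build on the host, then copy with .to(...)
alibi_blocks = torch.zeros(num_blocks, block_size)
alibi_blocks = alibi_blocks.to("hpu")

# Suggested pattern: allocate directly on the target device and skip the extra copy
alibi_blocks = torch.zeros(num_blocks, block_size, device="hpu")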
