Skip to content

Commit

Permalink
Initial Commit Chunked Prefill
Browse files Browse the repository at this point in the history
  • Loading branch information
hlahkar committed Dec 20, 2024
1 parent da61ecf commit af4b0ad
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 61 deletions.
16 changes: 7 additions & 9 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-3.2-1B",
#"meta-llama/Llama-3.2-1B",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("chunked_prefill_token_size", [4,])
@pytest.mark.parametrize("enforce_eager", [True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
#@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
Expand All @@ -42,14 +42,12 @@ def test_models(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
attention_backend: str,
monkeypatch,
) -> None:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
override_backend_env_variable(monkeypatch, attention_backend)
#override_backend_env_variable(monkeypatch, attention_backend)

max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
Expand All @@ -76,7 +74,7 @@ def test_models(
)


@multi_gpu_test(num_gpus=2)
'''@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
Expand Down Expand Up @@ -320,4 +318,4 @@ def test_with_prefix_caching_cpu(
chunk_size,
1,
dtype,
)
)'''
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
return x

if device is None:
device = "cpu" if current_platform.is_cpu() else "cuda"
device = "cpu" if current_platform.is_cpu() or current_platform.is_hpu() else "cuda"

Check failure on line 249 in tests/conftest.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

tests/conftest.py:249:81: E501 Line too long (96 > 80)

if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
Expand Down
140 changes: 110 additions & 30 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
cross_block_scales: Optional[torch.Tensor] = None
cross_block_usage: Optional[torch.Tensor] = None
cross_attn_bias: Optional[torch.Tensor] = None
decode_slot_mapping: Optional[torch.Tensor] = None
decode_block_list: Optional[torch.Tensor] = None


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Expand Down Expand Up @@ -202,31 +204,96 @@ def forward(
v_scale=v_scale,
)

batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
hidden_size: int = 0
if attn_metadata.num_prefill_tokens > 0:
# prefill preprocessing
prefill_query = query[:attn_metadata.num_prefill_tokens]
prefill_key = key[:attn_metadata.num_prefill_tokens]
prefill_value = value[:attn_metadata.num_prefill_tokens]
hidden_size = prefill_query.shape[-1]
print(prefill_query.shape, hidden_size)
prefill_query = prefill_query.reshape(attn_metadata.num_prefills,
attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,

Check failure on line 216 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:216:81: E501 Line too long (95 > 80)
hidden_size)
hidden_size = prefill_key.shape[-1]
print(prefill_key.shape, hidden_size)
prefill_key = prefill_key.reshape(attn_metadata.num_prefills,
attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,

Check failure on line 221 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:221:81: E501 Line too long (95 > 80)
hidden_size)
hidden_size = prefill_value.shape[-1]
print(prefill_value.shape, hidden_size)
prefill_value = prefill_value.reshape(attn_metadata.num_prefills,
attn_metadata.num_prefill_tokens // attn_metadata.num_prefills,

Check failure on line 226 in vllm/attention/backends/hpu_attn.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/attention/backends/hpu_attn.py:226:81: E501 Line too long (95 > 80)
hidden_size)
prefill_batch_size, prefill_seq_len, prefill_hidden_size = prefill_query.shape
_, seq_len_kv, _ = prefill_key.shape
prefill_query = prefill_query.view(-1, self.num_heads, self.head_size)
prefill_key = prefill_key.view(-1, self.num_kv_heads, self.head_size)
prefill_value = prefill_value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)

if attn_metadata.is_prompt:
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
prefill_key_cache = self.k_cache(prefill_key, key_cache, block_indices,
block_offsets)
prefill_value_cache = self.v_cache(prefill_value, value_cache, block_indices,
block_offsets)
else:
# decode preprocessing
decode_query = query[attn_metadata.num_prefill_tokens:]
decode_key = key[attn_metadata.num_prefill_tokens:]
decode_value = value[attn_metadata.num_prefill_tokens:]
hidden_size = decode_query.shape[-1]
print(decode_query.shape, hidden_size)
decode_query = decode_query.reshape(attn_metadata.num_decode_tokens,
1, hidden_size)
hidden_size = decode_key.shape[-1]
print(decode_key.shape, hidden_size)
decode_key = decode_key.reshape(attn_metadata.num_decode_tokens,
1, hidden_size)
hidden_size = decode_value.shape[-1]
print(decode_value.shape, hidden_size)
decode_value = decode_value.reshape(attn_metadata.num_decode_tokens,
1, hidden_size)
decode_batch_size, decode_seq_len, decode_hidden_size = decode_query.shape
_, seq_len_kv, _ = decode_key.shape
decode_query = decode_query.view(-1, self.num_heads, self.head_size)
decode_key = decode_key.view(-1, self.num_kv_heads, self.head_size)
decode_value = decode_value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
decode_key_cache = self.k_cache(decode_key, key_cache, block_indices,
block_offsets)
decode_value_cache = self.v_cache(decode_value, value_cache, block_indices,
block_offsets)


prompt_output: torch.Tensor = None
decode_output: torch.Tensor = None
batch_size: int = 0
seq_len: int = 0
if attn_metadata.num_prefills > 0:
# Prompt run.
query = prefill_query
key = prefill_key
value = prefill_value
batch_size = prefill_batch_size
seq_len = prefill_seq_len
hidden_size = prefill_hidden_size
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
Expand Down Expand Up @@ -265,8 +332,8 @@ def forward(
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
key_cache=key_cache,
value_cache=value_cache,
key_cache=prefill_key_cache,
value_cache=prefill_value_cache,
block_list=attn_metadata.block_list,
attn_bias=attn_metadata.attn_bias,
scale=self.scale,
Expand All @@ -275,13 +342,19 @@ def forward(
softmax_op=self.softmax,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
prompt_output = out.reshape(batch_size, seq_len, hidden_size)
if attn_metadata.num_decode_tokens > 0:
# Decoding run.
output = HPUPagedAttention.forward_decode(
query = decode_query
key = decode_key
value = decode_value
batch_size = decode_batch_size
seq_len = decode_seq_len
hidden_size = decode_hidden_size
decode_output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
key_cache=decode_key_cache,
value_cache=decode_value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
Expand All @@ -295,7 +368,14 @@ def forward(
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
if not decode_output:
return prompt_output.view(batch_size * seq_len, hidden_size)
elif not prompt_output:
return decode_output.view(batch_size * seq_len, hidden_size)
else:
prompt_output = prompt_output.view(batch_size * seq_len, hidden_size)
decode_output = decode_output.view(batch_size * seq_len, hidden_size)
return torch.cat(prompt_output, decode_output)

def forward_encoder_decoder(
self,
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class EngineArgs:
enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
use_padding_aware_scheduling: bool = current_platform.is_hpu()
use_padding_aware_scheduling: bool = False
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def _prepare_seq_groups(
else:
# Decode
prompt_logprob_len = 0
query_len = query_lens[i] if query_lens is not None and len(
query_lens) > 0 else 1
query_len = 1
sample_len = len(seq_ids) * query_len if do_sample else 0

if sampling_params.seed is not None and generators is not None:
Expand Down
Loading

0 comments on commit af4b0ad

Please sign in to comment.