Skip to content

Commit

Permalink
Add limit for decode block bucket size (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfylcek authored Nov 25, 2024
1 parent 61334c5 commit ac9740d
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions vllm_hpu_extension/bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, max_num_seqs, max_num_prefill_seqs, block_size,
         self.block_size = block_size
         self.max_num_batched_tokens = max_num_batched_tokens
         self._setup_buckets()
+        self.num_hpu_blocks = None

def _setup_buckets(self) -> None:
# FIXME: The default values should be max_model_len
Expand Down Expand Up @@ -111,8 +112,10 @@ def get_padded_prompt_seq_len(self, seq_len):
self.global_state.prompt_seq_bucket_cfg)

     def get_padded_decode_num_blocks(self, num_blocks):
-        return find_bucket(num_blocks,
-                           self.global_state.decode_block_bucket_cfg)
+        assert self.num_hpu_blocks is not None, "num_hpu_blocks is not set"
+        bucket_size = find_bucket(num_blocks,
+                                  self.global_state.decode_block_bucket_cfg)
+        return min(bucket_size, self.num_hpu_blocks)

def get_padded_batch_size(self, batch_size, is_prompt):
if is_prompt:
Expand Down

0 comments on commit ac9740d

Please sign in to comment.