Fix the StarCoder warmup issue
Signed-off-by: yuanwu <[email protected]>
yuanwu2017 committed Dec 1, 2024
1 parent b83419a commit 4586325
Showing 1 changed file with 50 additions and 74 deletions.
124 changes: 50 additions & 74 deletions server/text_generation_server/models/causal_lm.py
@@ -61,28 +61,13 @@
BATCH_BUCKET_SIZE= int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 2))


PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = []
DECODE_WARMUP_BATCH_SIZE_LIST = []


def torch_compile_for_eager(func):
if LAZY_MODE == 1:
return func
return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True})


def round_up(warmup_list:list, num) :
i = 0
if len(warmup_list) == 0:
return num

for i in warmup_list:
if num <= i :
break
return i

def round_up(number, k):
return (number + k - 1) // k * k

def to_tensor_indices(indices, device):
return torch.tensor(indices, dtype=torch.long, device=device)
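
The core of this fix: the old round_up picked the next bucket out of a pre-recorded warmup list, while the new helper simply rounds up to the nearest multiple of k. A minimal sketch of the two behaviours (the bucket list and inputs below are made up for illustration):

def round_up_old(warmup_list, num):
    # Old helper: pick the first pre-recorded warmup bucket that fits `num`
    # (or `num` itself when no buckets have been recorded yet).
    i = 0
    if len(warmup_list) == 0:
        return num
    for i in warmup_list:
        if num <= i:
            break
    return i

def round_up_new(number, k):
    # New helper: round up to the nearest multiple of `k`.
    return (number + k - 1) // k * k

assert round_up_old([2, 4, 8, 16], 5) == 8   # next recorded bucket
assert round_up_new(5, 8) == 8               # next multiple of BATCH_BUCKET_SIZE=8
assert round_up_new(9, 8) == 16
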
@@ -372,14 +357,13 @@ def move_data(self, src_batches):
self.set_tensor_groups(dst_tensors)

@classmethod
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "CausalLMBatch":
def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
if not all(b.past_key_values is not None for b in batches):
raise ValueError("KV cache not allocated! Cannot recombine before prefill!")

total_requests = sum(len(b) for b in batches)
new_bs = total_requests
if is_warmup is False :
new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)

batch_id = batches[0].batch_id
device = batches[0].input_ids.device
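
With the module-level warmup lists gone, recombine now pads the combined batch up to the next multiple of BATCH_BUCKET_SIZE instead of looking up a recorded decode bucket. A small worked example, assuming the default BATCH_BUCKET_SIZE of 8 from the env-var defaults above:

BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

# Three decode batches with 3, 2 and 6 live requests are recombined:
total_requests = 3 + 2 + 6                         # 11
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
assert new_bs == 16                                # padded with 5 dummy slots to hit the bucket
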
@@ -481,7 +465,6 @@ def from_pb(
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
is_warmup: bool = False,
) -> "CausalLMBatch":
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -503,7 +486,7 @@ def from_pb(
# TODO: by tokenizing all inputs at once we loose information on actual input lengths
# this means that we cannot shift inputs to the left after a long input sequence
# was filtered out
new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
missing_inputs = new_bs - len(inputs)
dummy_inputs = ["?"] * missing_inputs
parameters = [r.parameters for r in pb.requests]
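
Similarly, from_pb now derives the padded prefill batch size from PREFILL_BATCH_BUCKET_SIZE and fills the gap with dummy "?" inputs. A sketch assuming the default bucket of 2:

PREFILL_BATCH_BUCKET_SIZE = 2

def round_up(number, k):
    return (number + k - 1) // k * k

requests = ["req_a", "req_b", "req_c"]                        # 3 real prefill requests
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)   # -> 4
missing_inputs = new_bs - len(requests)                       # 1 dummy slot
dummy_inputs = ["?"] * missing_inputs
assert new_bs == 4 and dummy_inputs == ["?"]
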
@@ -533,7 +516,7 @@ def from_pb(
left_padding = max_input_length - input_len
if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
else:
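
The per-request bucketed sequence length uses the same helper (with PREFILL_BATCH_BUCKET_SIZE as the step, as written in this hunk) and is clamped to the request's max_input_length. A worked example of the clamp, using hypothetical lengths:

PREFILL_BATCH_BUCKET_SIZE = 2        # default step used in this hunk
max_input_length = 1024              # hypothetical request limit

def round_up(number, k):
    return (number + k - 1) // k * k

def bucketed_len(input_len):
    rounded_seq_len = round_up(input_len + 1, PREFILL_BATCH_BUCKET_SIZE)
    return rounded_seq_len - 1 if rounded_seq_len <= max_input_length else max_input_length - 1

assert bucketed_len(1023) == 1023    # fits: rounded to 1024, minus one
assert bucketed_len(1024) == 1023    # over the limit: clamped to max_input_length - 1
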
@@ -593,8 +576,8 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch":
return cls.recombine(batches, pad_token_id, is_warmup)
def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
return cls.recombine(batches, pad_token_id)

def __len__(self):
return len(self.requests)
@@ -895,7 +878,7 @@ def forward(

@tracer.start_as_current_span("generate_token")
def generate_token(
self, batches: List[CausalLMBatch], is_warmup: bool = False
self, batches: List[CausalLMBatch]
) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# Results
@@ -979,20 +962,20 @@ def generate_token(

# Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1:
batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup)
batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
else:
batch = batches[0]

prefill = batch.past_key_values is None

# Check if we need to do any bookkeeping first
if not prefill:
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)

scenario = 'PREFILL' if prefill else 'GENERATE'
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
self.model.clear_cache()
self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
dbg_trace(
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
assert batch.right_padding > 0, 'No more room for next token!'
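
The HPU-graph cache bookkeeping keeps the same shape, but the bucket now comes from BATCH_BUCKET_SIZE rounding instead of the recorded warmup list. A standalone sketch of that check, with a stub in place of the real model:

BATCH_BUCKET_SIZE = 8

def round_up(number, k):
    return (number + k - 1) // k * k

class FakeModel:
    def clear_cache(self):
        print("HPU graph cache cleared")

model, prev_bs = FakeModel(), None
for batch_size in (5, 7, 9, 12):        # successive decode batch sizes
    bucket = round_up(batch_size, BATCH_BUCKET_SIZE)
    if bucket != prev_bs:               # 5 and 7 share bucket 8; 9 and 12 share 16
        model.clear_cache()             # cleared twice in this sequence
        prev_bs = bucket
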
@@ -1194,100 +1177,93 @@ def generate_warmup_batch(self, request, seq_len, batch_size):
for i in range(len(batch.requests) - batch_size):
batch.requests.pop()

return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup=True)
return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)


def warmup(self, request) -> None:
is_warmup = True
MAX_TOTAL_TOKENS = request.max_total_tokens
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup)
batch = self.batch_type.from_pb(request.batch, self.tokenizer, self.dtype, self.device)
try:
# max prefill batch size warmup
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, prefill_batch, _ = self.generate_token([batch])
except:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
)

del prefill_batch
#warmup decode batch size
max_prefill_batch_size = batch.input_ids.shape[0]
del batch
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)

self.limit_hpu_graph = True
try:
for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
for batch_size in decode_batch_size_list:
batches= []
iters = math.floor(batch_size/max_prefill_batch_size)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
for i in range(iters):
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch)

if batch_size % max_prefill_batch_size != 0:
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(batch)
_, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch)

_, decode_batch, _ = self.generate_token(batches, is_warmup)
_, decode_batch, _ = self.generate_token(batches)
del decode_batch
batches.clear()
except:
DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
self.model.clear_cache()
if len(DECODE_WARMUP_BATCH_SIZE_LIST) > 0:
logger.warning(
f"Not enough memory to warmup all batch size of decode."
f"You need to decrease `--max-batch-total-tokens`"
)
else:
raise RuntimeError(
f"Not enough memory to warmup decode batch_size({max_decode_batch_size})."
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
f"You need to decrease `--max-batch-total-tokens`"
)
DECODE_WARMUP_BATCH_SIZE_LIST.sort()
decode_batch_size_list.sort()
MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Decode batch size list:{decode_batch_size_list}\n"
f"Memory stats: {mem_stats} "
)

# Warmup prefill batch_size
max_input_length = request.max_input_length
max_prefill_batch_size = batch.input_ids.shape[0]
seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF

i = 0
while seq_len <= max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
i += 1

if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)

prefill_batch_size_list = []
prefill_seqlen_list = []
#Prefill and decode warmup
try:
for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size + 1, PREFILL_BATCH_BUCKET_SIZE):
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
for batch_size in range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE):
prefill_batch_size_list.append(batch_size)
for seq_len in range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF):
prefill_seqlen_list.append(seq_len)
batch = self.generate_warmup_batch(request, seq_len, batch_size)
_, prefill_batch, _ = self.generate_token([batch])
del batch
del prefill_batch
except:
raise RuntimeError(
f"Not enough memory to run following prefill batch_size."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"Prefill batch size list:{prefill_batch_size_list}"
f"Prefill sequence length list:{prefill_seqlen_list}"
f"You need to decrease `--max-batch-prefill-tokens`"
)

prefill_batch_size_list.sort()
prefill_seqlen_list.sort()
limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
if limit_hpu_graph == False:
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Prefill batch size list:{prefill_batch_size_list}\n"
f"Prefill sequence length list:{prefill_seqlen_list}\n"
f"Memory stats: {mem_stats} "
)
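
Taken together, the rewritten warmup derives its decode buckets from BATCH_BUCKET_SIZE and sweeps the prefill grid from the largest shapes downwards, instead of recording whatever shapes it managed to warm up in module-level lists. A sketch of the shapes it would visit under hypothetical limits (the real values come from the warmup request and the environment variables at the top of the file):

import math

BATCH_BUCKET_SIZE = 8
PREFILL_BATCH_BUCKET_SIZE = 2
PAD_SEQUENCE_TO_MULTIPLE_OF = 128
MAX_TOTAL_TOKENS = 2048              # hypothetical
MAX_BATCH_TOTAL_TOKENS = 65536       # hypothetical
max_prefill_batch_size = 4           # hypothetical
max_input_length = 512               # hypothetical

def round_up(number, k):
    return (number + k - 1) // k * k

# Decode warmup buckets (largest first, as in the diff).
max_decode_batch_size = round_up(math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS), BATCH_BUCKET_SIZE)
decode_batch_size_list = list(range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE))
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)        # [32, 24, 16, 8]

# Prefill warmup grid (largest shapes first, as in the diff).
prefill_batch_sizes = list(range(max_prefill_batch_size, 0, -PREFILL_BATCH_BUCKET_SIZE))
prefill_seq_lens = list(range(max_input_length, 0, -PAD_SEQUENCE_TO_MULTIPLE_OF))
print(prefill_batch_sizes)           # [4, 2]
print(prefill_seq_lens)              # [512, 384, 256, 128]

After the decode loop the list is sorted ascending, so MAX_BATCH_TOTAL_TOKENS is recomputed from the largest decode bucket via MAX_TOTAL_TOKENS * decode_batch_size_list[-1].
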
