From 18e77a5cc7aeab1784c5c8d6f4cb2b6b8d044078 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:28:10 +0200 Subject: [PATCH 01/18] wip --- proto/v3/generate.proto | 24 ++++ router/src/infer/v3/queue.rs | 3 + router/src/infer/v3/scheduler.rs | 35 +++++- router/src/lib.rs | 3 + .../models/causal_lm.py | 1 + .../models/flash_causal_lm.py | 103 +++++++----------- .../models/idefics_causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + .../models/seq2seq_lm.py | 1 + server/text_generation_server/models/types.py | 2 + .../models/vlm_causal_lm.py | 4 +- 11 files changed, 112 insertions(+), 66 deletions(-) diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 01cc43fd320..d57fbbad4b1 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -17,6 +17,8 @@ service TextGenerationService { rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Decode token for a list of prefilled batches rpc Decode (DecodeRequest) returns (DecodeResponse); + /// Update batch + rpc Update(UpdateRequest) returns (UpdateResponse); /// Health check rpc Health (HealthRequest) returns (HealthResponse); } @@ -198,6 +200,8 @@ message Generation { optional GeneratedText generated_text = 4; /// Top tokens repeated Tokens top_tokens = 5; + /// Current length of the request: prompt tokens + number of generated tokens until this point + uint32 current_length = 6; } message FilterBatchRequest { @@ -251,6 +255,26 @@ message DecodeResponse { optional uint64 concat_ns = 6; } +message ExtendedRequest { + /// Request ID + uint64 request_id = 1; + /// Paged attention blocks to add + repeated uint32 blocks = 2; + /// Paged attention slots to add + repeated uint32 slots = 3; +} + +message UpdateRequest { + /// Batch ID + uint64 batch_id = 1; + /// Requests to update + repeated ExtendedRequest extend_requests = 2; + /// Requests to terminate + repeated uint64 terminated_request_ids = 3; +} + +message UpdateResponse {} + message WarmupRequest { /// Batch to warmup on Batch batch = 1; diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 0b66142abb1..1522679445c 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -33,6 +33,8 @@ pub(crate) struct Entry { pub batch_time: Option, /// Block Allocation pub block_allocation: Option, + /// Current length (in tokens) of the request (prompt tokens + generated_tokens) + pub current_length: u32 } /// Request Queue @@ -498,6 +500,7 @@ mod tests { queue_time: Instant::now(), batch_time: None, block_allocation: None, + current_length: 0, }; (entry, receiver_tx) } diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ad03dd8375f..bf52e69f72a 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -88,6 +88,7 @@ impl Scheduler for SchedulerV3 { queue_time: Instant::now(), batch_time: None, block_allocation: None, + current_length: input_length, }); // Notify the background task that we have a new entry in the queue that needs @@ -287,6 +288,8 @@ async fn decode( // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); + filter_update_allocations(client, entries).await; + // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; @@ -355,8 +358,9 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap, entries: &mut IntMap) { + // let mut extend_entries = 
Vec::with_capacity(entries.len()); + // let mut finish_entries = Vec::with_capacity(entries.len()); + + // for (request_id, entry) in entries.into_iter() { + // tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length); + // + // if let Some(block_allocation) = &mut entry.block_allocation { + // tracing::info!("Allocation {:?}", block_allocation); + // + // if entry.current_length > block_allocation.allocated_tokens { + // // We need to add new blocks to this entry + // let remaining_tokens = block_allocation.total_tokens - entry.current_length; + // match block_allocation.extend(remaining_tokens).await { + // true => { + // + // }, + // false => { + // + // } + // } + // } + // } + // } +} + /// Send responses through the `entry` response channel fn send_responses( generation: Generation, diff --git a/router/src/lib.rs b/router/src/lib.rs index b6902c497c1..52c5aa461fd 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason { EndOfSequenceToken, #[schema(rename = "stop_sequence")] StopSequence, + #[schema(rename = "out_of_pages")] + OutOfPages } impl std::fmt::Display for FinishReason { @@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason { FinishReason::Length => write!(f, "length"), FinishReason::EndOfSequenceToken => write!(f, "eos_token"), FinishReason::StopSequence => write!(f, "stop_sequence"), + FinishReason::OutOfPages => write!(f, "out_of_pages"), } } } diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index e896c831bd3..2fe0f56e943 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -746,6 +746,7 @@ def generate_token( ), generated_text, top_tokens, + new_input_length ) generations.append(generation) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d16d371068e..da5fa9dbf93 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -79,8 +79,6 @@ class FlashCausalLMBatch(Batch): # Paged Attention values # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode slot_indices: torch.Tensor @@ -88,8 +86,10 @@ class FlashCausalLMBatch(Batch): block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor + # list of length b of list of length s_i + slots: List[List[int]] # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: torch.Tensor + slots_tensor: torch.Tensor max_seqlen: int @@ -154,7 +154,6 @@ def from_tokenized( sliding_window = get_sliding_windows() position_ids = [] cu_seqlen_prefill = [0] - start_slots = [] slot_indices = [] prefill_cache_indices = [] @@ -176,7 +175,6 @@ def from_tokenized( # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -186,6 +184,7 @@ def from_tokenized( block_tables = [] slots = [] + flat_slots = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -204,6 +203,9 @@ def from_tokenized( input_length = 
len(tokenized_input) input_lengths.append(input_length) + speculative_length = get_speculate() + speculative_length = 0 if speculative_length is None else speculative_length + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) @@ -226,13 +228,10 @@ def from_tokenized( top_n_tokens.append(r.top_n_tokens) # Paged attention - # Remove one as the first token des not have a past - speculative_length = get_speculate() - speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length - # blocks and slots can be empty (for example in warmup) if not r.blocks: + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + speculative_length needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) @@ -247,15 +246,15 @@ def from_tokenized( request_slots = r.slots block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + len(flat_slots), + len(flat_slots) + input_length, dtype=torch.int64, ) + slots.append(request_slots) + flat_slots.extend(request_slots) slot_indices.append(request_slot_indices) # Create tensor to slice into the kv tensor in prefill @@ -289,7 +288,6 @@ def from_tokenized( # Update cumulative_length += input_length - cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -299,7 +297,6 @@ def from_tokenized( next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -356,7 +353,7 @@ def from_tokenized( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) + slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device) block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) @@ -372,11 +369,11 @@ def from_tokenized( position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + slots_tensor=slots_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -423,18 +420,13 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Used to index into tensors indices = [] - # slots to keep after filtering - slot_filtering_indices = torch.zeros( - self.slots.shape[0], dtype=torch.bool, device=device - ) - - # Create on CPU to only move to GPU once instead of at every copy - slot_indices = torch.empty(len(request_ids), dtype=torch.int64) + slot_indices = [] max_seqlen = 0 requests = [] - start_slots = [] block_tables = [] + slots = [] + flat_slots = [] all_input_ids = [] input_lengths = [] @@ -446,8 +438,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 for i, request_id in enumerate(request_ids): idx = 
self.requests_idx_mapping[request_id] @@ -471,27 +461,17 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": top_n_tokens.append(self.top_n_tokens[idx]) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - request_block_table = self.block_tables[idx] num_blocks += len(request_block_table) block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 + # List of slots allocated for this request + request_slots = self.slots[idx] + slots.append(request_slots) - # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True - - cumulative_max_length += request_input_length + remaining_tokens - 1 + # Index + slot_indices.append(len(flat_slots) + request_input_length - 1) + flat_slots.extend(request_slots) max_blocks = max(max_blocks, len(request_block_table)) @@ -501,17 +481,15 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + # Allocate on GPU + slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device) + slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) return type(self)( batch_id=self.batch_id, @@ -521,11 +499,11 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + slots_tensor=slots_tensor, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -560,13 +538,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max_seqlen = 0 for b in batches: total_batch_size += len(b) - total_slots += len(b.slots) + total_slots += len(b.slots_tensor) num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) + # When we filter, we do not recompute this value so we do so here max_length = max( max_length, max( @@ -582,7 +561,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) + slots_tensor = batches[0].slots_tensor.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size @@ -597,7 +576,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_batch_size, ) - start_slots = [] + slots = [] block_tables = [] all_input_ids = 
[] @@ -627,7 +606,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) + slots_end_index = cumulative_slots + len(batch.slots_tensor) # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids @@ -635,7 +614,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots[slots_start_index:slots_end_index] = batch.slots + slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] @@ -645,8 +624,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] - start_slots.append(batch.start_slots + cumulative_slots) - + slots.extend(batch.slots) block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) @@ -662,9 +640,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Update cumulative_batch_size += len(batch) - cumulative_slots += len(batch.slots) - - start_slots = torch.concat(start_slots) + cumulative_slots += len(batch.slots_tensor) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -688,11 +664,11 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + slots_tensor=slots_tensor, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -993,7 +969,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] + slots = batch.slots_tensor[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1032,7 +1008,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] + slots = batch.slots_tensor[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1374,6 +1350,7 @@ def generate_token( ), generated_text, top_tokens, + input_length + n_accepted_ids ) generations.append(generation) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index f507d669936..44b2189921c 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -829,6 +829,7 @@ def generate_token( ), generated_text, top_tokens, + new_input_length ) generations.append(generation) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 3133a137d9a..8182eb46936 100644 --- 
a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -775,6 +775,7 @@ def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, in ), generated_text, top_tokens, + new_input_length ) generations.append(generation) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3bd095564c3..74ea2dabc18 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -801,6 +801,7 @@ def generate_token( ), generated_text, top_tokens, + new_decoder_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 339b733b5f6..1c7a157a453 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -84,6 +84,7 @@ class Generation: generated_text: Optional[GeneratedText] # Optional for now, since it's not yet supported for every model. top_tokens: Optional[List[Tokens]] + current_length: int def to_pb(self) -> generate_pb2.Generation: return generate_pb2.Generation( @@ -100,4 +101,5 @@ def to_pb(self) -> generate_pb2.Generation: if self.top_tokens is not None else None ), + current_length=self.current_length, ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 59a6fab1cd0..b1ccd140e36 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -228,7 +228,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] + slots = batch.slots_tensor[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -267,7 +267,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] + slots = batch.slots_tensor[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices From 1cc86930a682f1a1f63434e798220bf2ebe7ca21 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 5 Jun 2024 17:01:06 +0200 Subject: [PATCH 02/18] wip --- proto/v3/generate.proto | 33 +++----- router/client/src/v3/client.rs | 4 +- router/client/src/v3/mod.rs | 2 +- router/client/src/v3/sharded_client.rs | 6 +- router/src/infer/mod.rs | 3 + router/src/infer/v3/block_allocator.rs | 8 ++ router/src/infer/v3/queue.rs | 2 +- router/src/infer/v3/scheduler.rs | 84 +++++++++++-------- router/src/lib.rs | 3 - router/src/server.rs | 1 + .../models/causal_lm.py | 8 +- .../models/flash_causal_lm.py | 67 +++++++-------- .../models/idefics_causal_lm.py | 8 +- server/text_generation_server/models/mamba.py | 8 +- .../models/seq2seq_lm.py | 6 +- server/text_generation_server/models/types.py | 2 +- .../models/vlm_causal_lm.py | 6 +- server/text_generation_server/server.py | 2 +- 18 files changed, 139 insertions(+), 114 deletions(-) diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index d57fbbad4b1..192cd111bb7 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -17,8 +17,6 @@ service TextGenerationService { rpc Prefill 
(PrefillRequest) returns (PrefillResponse); /// Decode token for a list of prefilled batches rpc Decode (DecodeRequest) returns (DecodeResponse); - /// Update batch - rpc Update(UpdateRequest) returns (UpdateResponse); /// Health check rpc Health (HealthRequest) returns (HealthResponse); } @@ -204,11 +202,20 @@ message Generation { uint32 current_length = 6; } +message UpdatedRequest { + /// Request ID + uint64 id = 1; + /// Paged attention blocks + repeated uint32 blocks = 2; + /// Paged attention slots + repeated uint32 slots = 3; +} + message FilterBatchRequest { /// Batch ID uint64 batch_id = 1; /// Requests to keep - repeated uint64 request_ids = 2; + repeated UpdatedRequest updated_requests = 2; } message FilterBatchResponse { @@ -255,26 +262,6 @@ message DecodeResponse { optional uint64 concat_ns = 6; } -message ExtendedRequest { - /// Request ID - uint64 request_id = 1; - /// Paged attention blocks to add - repeated uint32 blocks = 2; - /// Paged attention slots to add - repeated uint32 slots = 3; -} - -message UpdateRequest { - /// Batch ID - uint64 batch_id = 1; - /// Requests to update - repeated ExtendedRequest extend_requests = 2; - /// Requests to terminate - repeated uint64 terminated_request_ids = 3; -} - -message UpdateResponse {} - message WarmupRequest { /// Batch to warmup on Batch batch = 1; diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 9a3892fb186..8cefd3137b2 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -90,11 +90,11 @@ impl Client { pub async fn filter_batch( &mut self, batch_id: u64, - request_ids: Vec, + updated_requests: Vec, ) -> Result> { let request = tonic::Request::new(FilterBatchRequest { batch_id, - request_ids, + updated_requests, }) .inject_context(); let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs index 4a1296a2247..df2bb380734 100644 --- a/router/client/src/v3/mod.rs +++ b/router/client/src/v3/mod.rs @@ -8,6 +8,6 @@ pub use client::Client; pub use pb::generate::v3::{ input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, - StoppingCriteriaParameters, Tokens, + StoppingCriteriaParameters, Tokens, UpdatedRequest, }; pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 94002f55064..a066176ce5e 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -10,7 +10,7 @@ use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; use v3::{ Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, - NextTokenChooserParameters, Request, StoppingCriteriaParameters, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest, }; #[derive(Debug, Clone)] @@ -84,12 +84,12 @@ impl ShardedClient { pub async fn filter_batch( &mut self, batch_id: u64, - request_ids: Vec, + updated_requests: Vec, ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone()))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 20630c1b0cd..3b61e46667f 100644 
--- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -506,6 +506,8 @@ pub enum InferError { TemplateError(#[from] minijinja::Error), #[error("Tool error: {0}")] ToolError(String), + #[error("Request could not be re-allocated: out of pages")] + OutOfPages, } impl InferError { @@ -517,6 +519,7 @@ impl InferError { InferError::IncompleteGeneration => "incomplete_generation", InferError::TemplateError(_) => "template_error", InferError::ToolError(_) => "tool_error", + InferError::OutOfPages => "out_of_pages", } } } diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 7467fd85997..811efb262ec 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -8,6 +8,12 @@ pub(crate) struct BlockAllocation { block_allocator: BlockAllocator, } +impl BlockAllocation { + pub(crate) fn len(&self) -> usize { + self.slots.len() + } +} + impl Drop for BlockAllocation { fn drop(&mut self) { self.block_allocator.free(self.blocks.clone()) @@ -83,6 +89,8 @@ async fn block_allocator_task( tokens, response_sender, } => { + // let tokens = 16; + // Apply window size let (required_blocks, repeats) = { let (tokens, repeats) = match window_size { diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 1522679445c..1ac06ae97c9 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -34,7 +34,7 @@ pub(crate) struct Entry { /// Block Allocation pub block_allocation: Option, /// Current length (in tokens) of the request (prompt tokens + generated_tokens) - pub current_length: u32 + pub current_length: u32, } /// Request Queue diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index bf52e69f72a..faa899ecd1d 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -10,7 +10,7 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; +use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest}; use text_generation_client::ClientError; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; @@ -288,7 +288,7 @@ async fn decode( // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); - filter_update_allocations(client, entries).await; + filter_update_allocations(entries).await; // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; @@ -323,7 +323,7 @@ async fn filter_batch( next_batch: Option, entries: &IntMap, ) -> Option { - let mut batch = next_batch?; + let batch = next_batch?; // No need to filter if batch.size as usize == entries.len() { @@ -331,11 +331,7 @@ async fn filter_batch( } let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { + if entries.is_empty() { // All requests have been filtered out // Next batch is now empty // Clear it from the Python shards cache @@ -344,8 +340,24 @@ async fn filter_batch( None } else { // Filter Python shard cache + let updated_requests = entries + .iter() + .map(|(request_id, entry)| { + let (blocks, slots) = entry + .block_allocation + .as_ref() + .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone())) + .unwrap_or((Vec::new(), Vec::new())); + UpdatedRequest { + id: *request_id, + blocks, + slots, + 
} + }) + .collect(); + // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() + client.filter_batch(id, updated_requests).await.unwrap() } } @@ -379,32 +391,36 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - // let mut extend_entries = Vec::with_capacity(entries.len()); - // let mut finish_entries = Vec::with_capacity(entries.len()); - - // for (request_id, entry) in entries.into_iter() { - // tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length); - // - // if let Some(block_allocation) = &mut entry.block_allocation { - // tracing::info!("Allocation {:?}", block_allocation); - // - // if entry.current_length > block_allocation.allocated_tokens { - // // We need to add new blocks to this entry - // let remaining_tokens = block_allocation.total_tokens - entry.current_length; - // match block_allocation.extend(remaining_tokens).await { - // true => { - // - // }, - // false => { - // - // } - // } - // } - // } - // } +async fn filter_update_allocations(entries: &mut IntMap) { + entries.retain(|request_id, entry| { + if entry.block_allocation.is_none() { + return true; + } + + // We can unwrap since we already validated above that block_allocation is not None + let mut block_allocation = entry.block_allocation.as_ref().unwrap(); + + // Nothing to update + if entry.current_length <= block_allocation.len() as u32 { + return true; + } + + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::OutOfPages; + metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages"); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + + false + }); } /// Send responses through the `entry` response channel diff --git a/router/src/lib.rs b/router/src/lib.rs index 52c5aa461fd..b6902c497c1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1085,8 +1085,6 @@ pub(crate) enum FinishReason { EndOfSequenceToken, #[schema(rename = "stop_sequence")] StopSequence, - #[schema(rename = "out_of_pages")] - OutOfPages } impl std::fmt::Display for FinishReason { @@ -1095,7 +1093,6 @@ impl std::fmt::Display for FinishReason { FinishReason::Length => write!(f, "length"), FinishReason::EndOfSequenceToken => write!(f, "eos_token"), FinishReason::StopSequence => write!(f, "stop_sequence"), - FinishReason::OutOfPages => write!(f, "out_of_pages"), } } } diff --git a/router/src/server.rs b/router/src/server.rs index 30479b0e4dc..9df33739c40 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1859,6 +1859,7 @@ impl From for (StatusCode, Json) { InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS, }; ( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 2fe0f56e943..50a25a50bf0 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -158,7 +158,11 @@ def from_pb( ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["CausalLMBatch"]: + request_ids = [r.id for r in updated_requests] + if len(request_ids) == 0: raise ValueError("Batch must have at least one request") if len(request_ids) == len(self): @@ -746,7 +750,7 @@ def generate_token( ), generated_text, top_tokens, - new_input_length + new_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index da5fa9dbf93..c4c1cf9a035 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -82,14 +82,10 @@ class FlashCausalLMBatch(Batch): # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode slot_indices: torch.Tensor - # list of length b of list of length s_i // block_size - block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor - # list of length b of list of length s_i - slots: List[List[int]] # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots_tensor: torch.Tensor + slots: torch.Tensor max_seqlen: int @@ -183,7 +179,6 @@ def from_tokenized( max_blocks = 0 block_tables = [] - slots = [] flat_slots = [] # Parse batch @@ -253,7 +248,6 @@ def from_tokenized( len(flat_slots) + input_length, dtype=torch.int64, ) - slots.append(request_slots) flat_slots.extend(request_slots) slot_indices.append(request_slot_indices) @@ -353,7 +347,7 @@ def from_tokenized( top_n_tokens, device=device, dtype=torch.int64 ) - slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device) + slots = torch.tensor(flat_slots, dtype=torch.int64, 
device=device) block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) @@ -370,10 +364,8 @@ def from_tokenized( cu_seqlen_prefill=cu_seqlen_prefill, prefill_cache_indices=prefill_cache_indices, slot_indices=slot_indices, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - slots_tensor=slots_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -405,11 +397,13 @@ def from_pb( return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": - if len(request_ids) == 0: + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["FlashCausalLMBatch"]: + if len(updated_requests) == 0: raise ValueError("Batch must have at least one request") # We assume that if len(requests) == len(self) then the requests are the same - if len(request_ids) == len(self): + if len(updated_requests) == len(self): return self device = self.input_ids.device @@ -425,7 +419,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": requests = [] block_tables = [] - slots = [] flat_slots = [] all_input_ids = [] @@ -439,7 +432,9 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": num_blocks = 0 max_blocks = 0 - for i, request_id in enumerate(request_ids): + for i, request in enumerate(updated_requests): + request_id = request.id + idx = self.requests_idx_mapping[request_id] indices.append(idx) requests_idx_mapping[request_id] = i @@ -461,13 +456,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": top_n_tokens.append(self.top_n_tokens[idx]) - request_block_table = self.block_tables[idx] + request_block_table = request.blocks num_blocks += len(request_block_table) block_tables.append(request_block_table) # List of slots allocated for this request - request_slots = self.slots[idx] - slots.append(request_slots) + request_slots = request.slots # Index slot_indices.append(len(flat_slots) + request_input_length - 1) @@ -479,7 +473,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] - block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] @@ -487,10 +480,20 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": self.speculative_ids[indices] if self.speculative_ids is not None else None ) + # Create block_tables_tensor on CPU + block_tables_tensor = torch.zeros( + (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + ) + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + # Allocate on GPU - slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device) + slots = torch.tensor(flat_slots, dtype=torch.int64, device=device) slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) + # Move to GPU + block_tables_tensor = block_tables_tensor.to(device) + return type(self)( batch_id=self.batch_id, requests=requests, @@ -500,10 +503,8 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": cu_seqlen_prefill=None, 
prefill_cache_indices=None, slot_indices=slot_indices, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - slots_tensor=slots_tensor, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -538,7 +539,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max_seqlen = 0 for b in batches: total_batch_size += len(b) - total_slots += len(b.slots_tensor) + total_slots += len(b.slots) num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 @@ -561,7 +562,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots_tensor = batches[0].slots_tensor.new_empty(total_slots) + slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size @@ -576,8 +577,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_batch_size, ) - slots = [] - block_tables = [] all_input_ids = [] input_lengths = [] @@ -606,7 +605,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots_tensor) + slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids @@ -614,7 +613,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor + slots[slots_start_index:slots_end_index] = batch.slots all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] @@ -624,8 +623,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] - slots.extend(batch.slots) - block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) input_lengths.extend(batch.input_lengths) @@ -640,7 +637,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch # Update cumulative_batch_size += len(batch) - cumulative_slots += len(batch.slots_tensor) + cumulative_slots += len(batch.slots) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, @@ -665,10 +662,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cu_seqlen_prefill=None, prefill_cache_indices=None, slot_indices=slot_indices, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, - slots_tensor=slots_tensor, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -969,7 +964,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots_tensor[batch.slot_indices] + slots = batch.slots[batch.slot_indices] input_lengths = 
batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1008,7 +1003,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots_tensor[batch.slot_indices] + slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1350,7 +1345,7 @@ def generate_token( ), generated_text, top_tokens, - input_length + n_accepted_ids + input_length + n_accepted_ids, ) generations.append(generation) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 44b2189921c..fd70ae5d149 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -214,7 +214,11 @@ def from_pb_processor( ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["IdeficsCausalLMBatch"]: + request_ids = [r.id for r in updated_requests] + # It deletes requests from the batch. For instance when client lost connection if len(request_ids) == 0: raise ValueError("Batch must have at least one request") @@ -829,7 +833,7 @@ def generate_token( ), generated_text, top_tokens, - new_input_length + new_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 8182eb46936..c8066aec4ba 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -195,7 +195,11 @@ def from_pb( max_tokens=max_tokens, ) - def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["MambaBatch"]: + request_ids = [r.id for r in updated_requests] + if len(request_ids) == 0: raise ValueError("Batch must have at least one request") if len(request_ids) == len(self): @@ -775,7 +779,7 @@ def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, in ), generated_text, top_tokens, - new_input_length + new_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 74ea2dabc18..1e4f7c2e54d 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -166,7 +166,11 @@ def from_pb( ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["Seq2SeqLMBatch"]: + request_ids = [r.id for r in updated_requests] + if len(request_ids) == 0: raise ValueError("Batch must have at least one request") if len(request_ids) == len(self): diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 1c7a157a453..50c14862762 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -28,7 +28,7 @@ def from_pb( raise NotImplementedError @abstractmethod - def filter(self, request_ids: List[int]) -> "Batch": + def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> 
"Batch": raise NotImplementedError @classmethod diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index b1ccd140e36..bc51e732b99 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -122,8 +122,10 @@ def concatenate(cls, batches): return batch @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]): - batch = super().filter(request_ids) + def filter( + self, updated_requests: List[generate_pb2.UpdatedRequest] + ) -> Optional["VlmCausalLMBatch"]: + batch = super().filter(updated_requests) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 569b6925a0e..a66c19a0ce5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -83,7 +83,7 @@ async def FilterBatch(self, request, context): batch = self.cache.pop(request.batch_id) if batch is None: raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - filtered_batch = batch.filter(request.request_ids) + filtered_batch = batch.filter(request.updated_requests) self.cache.set(filtered_batch) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) From 35f27cbcc1acfe51cd376935f31991ff35c53639 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 5 Jun 2024 18:47:16 +0200 Subject: [PATCH 03/18] working example --- router/src/infer/v3/block_allocator.rs | 72 +++++++++++++---- router/src/infer/v3/queue.rs | 16 ++-- router/src/infer/v3/scheduler.rs | 80 ++++++++++++------- .../models/flash_causal_lm.py | 3 - 4 files changed, 119 insertions(+), 52 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 811efb262ec..3e7cde893ed 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,10 +1,13 @@ -use std::cmp::min; +use std::cmp::{max, min}; +use thiserror::Error; use tokio::sync::{mpsc, oneshot}; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { pub blocks: Vec, pub slots: Vec, + prompt_tokens: u32, + decode_tokens: u32, block_allocator: BlockAllocator, } @@ -12,6 +15,14 @@ impl BlockAllocation { pub(crate) fn len(&self) -> usize { self.slots.len() } + + pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> { + let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1); + self.block_allocator + .clone() + .extend(self, remaining_tokens) + .await + } } impl Drop for BlockAllocation { @@ -48,11 +59,16 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + prompt_tokens: u32, + decode_tokens: u32, + ) -> Result { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { - tokens, + prompt_tokens, + decode_tokens, response_sender, }) .unwrap(); @@ -63,10 +79,32 @@ impl BlockAllocator { .map(|(blocks, slots)| BlockAllocation { blocks, slots, + prompt_tokens, + decode_tokens, block_allocator: self.clone(), }) } + pub(crate) async fn extend( + &self, + block_allocation: &mut BlockAllocation, + tokens: u32, + ) -> Result<(), AllocationError> { + let (response_sender, response_receiver) = oneshot::channel(); + 
self.block_allocator + .send(BlockAllocatorCommand::Allocate { + prompt_tokens: 0, + decode_tokens: tokens, + response_sender, + }) + .unwrap(); + + let (blocks, slots) = response_receiver.await.unwrap()?; + block_allocation.blocks.extend(blocks); + block_allocation.slots.extend(slots); + Ok(()) + } + pub(crate) fn free(&self, blocks: Vec) { self.block_allocator .send(BlockAllocatorCommand::Free { blocks }) @@ -86,10 +124,12 @@ async fn block_allocator_task( match cmd { BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), BlockAllocatorCommand::Allocate { - tokens, + prompt_tokens, + decode_tokens, response_sender, } => { - // let tokens = 16; + let decode_tokens = min(decode_tokens, block_size); + let tokens = prompt_tokens + decode_tokens; // Apply window size let (required_blocks, repeats) = { @@ -106,9 +146,8 @@ async fn block_allocator_task( (required_blocks, repeats) }; - let tokens = tokens as usize; let allocation = if required_blocks > free_blocks.len() as u32 { - None + Err(AllocationError::NotEnoughPages) } else { let blocks = free_blocks.split_off(free_blocks.len() - required_blocks as usize); @@ -116,15 +155,12 @@ async fn block_allocator_task( (required_blocks * block_size * repeats as u32) as usize, ); - 'slots: for block_id in blocks.repeat(repeats).iter() { + for block_id in blocks.repeat(repeats).iter() { for s in (block_id * block_size)..((block_id + 1) * block_size) { slots.push(s); - if slots.len() == tokens { - break 'slots; - } } } - Some((blocks, slots)) + Ok((blocks, slots)) }; response_sender.send(allocation).unwrap(); } @@ -138,7 +174,15 @@ enum BlockAllocatorCommand { blocks: Vec, }, Allocate { - tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prompt_tokens: u32, + decode_tokens: u32, + #[allow(clippy::type_complexity)] + response_sender: oneshot::Sender, Vec), AllocationError>>, }, } + +#[derive(Error, Debug)] +pub enum AllocationError { + #[error("Not enough pages")] + NotEnoughPages, +} diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 1ac06ae97c9..9a7b1084bbe 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -295,20 +295,20 @@ impl State { break; } - let tokens = entry.request.input_length - + entry.request.stopping_parameters.max_new_tokens - + self.speculate - - 1; - - match block_allocator.allocate(tokens).await { - None => { + let decode_tokens = + entry.request.stopping_parameters.max_new_tokens + self.speculate - 1; + match block_allocator + .allocate(entry.request.input_length, decode_tokens) + .await + { + Err(_) => { // Entry is over budget // Add it back to the front tracing::debug!("Over budget: not enough free blocks"); self.entries.push_front((id, entry)); break 'entry_loop; } - Some(block_allocation) => { + Ok(block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); Some(block_allocation) diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index faa899ecd1d..b76c5c50eab 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -247,7 +247,7 @@ async fn prefill( filter_send_generations(generations, entries); // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = filter_batch(client, next_batch, entries, false).await; metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); 
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); @@ -288,10 +288,10 @@ async fn decode( // Send generated tokens and filter stopped entries filter_send_generations(generations, entries); - filter_update_allocations(entries).await; + let updated = filter_update_allocations(entries).await; // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = filter_batch(client, next_batch, entries, updated).await; if let Some(concat_duration) = timings.concat { metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); @@ -322,11 +322,12 @@ async fn filter_batch( client: &mut ShardedClient, next_batch: Option, entries: &IntMap, + force_update: bool, ) -> Option { let batch = next_batch?; // No need to filter - if batch.size as usize == entries.len() { + if batch.size as usize == entries.len() && !force_update { return Some(batch); } @@ -348,6 +349,7 @@ async fn filter_batch( .as_ref() .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone())) .unwrap_or((Vec::new(), Vec::new())); + UpdatedRequest { id: *request_id, blocks, @@ -393,34 +395,58 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - entries.retain(|request_id, entry| { - if entry.block_allocation.is_none() { - return true; - } +async fn filter_update_allocations(entries: &mut IntMap) -> bool { + let ids: Vec = entries + .iter() + .filter_map(|(id, entry)| { + entry + .block_allocation + .as_ref() + .map(|block_allocation| { + if entry.current_length > block_allocation.len() as u32 { + // We need to re-allocate + Some(*id) + } else { + None + } + }) + .unwrap_or(None) + }) + .collect(); - // We can unwrap since we already validated above that block_allocation is not None - let mut block_allocation = entry.block_allocation.as_ref().unwrap(); + for id in ids.iter() { + // Get entry + // We can `expect` here as the request id should always be in the entries + let extension = { + let entry = entries + .get_mut(id) + .expect("ID not found in entries. This is a bug."); + entry + .block_allocation + .as_mut() + .unwrap() + .extend(entry.current_length) + .await + }; - // Nothing to update - if entry.current_length <= block_allocation.len() as u32 { - return true; - } + if extension.is_err() { + let entry = entries + .remove(id) + .expect("ID not found in entries. This is a bug."); - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); - let err = InferError::OutOfPages; - metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages"); - tracing::error!("{err}"); + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::OutOfPages; + metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages"); + tracing::error!("{err}"); - // unwrap_or is valid here as we don't care if the receiver is gone. - entry - .response_tx - .send(Err(err)) - .unwrap_or(()); + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry.response_tx.send(Err(err)).unwrap_or(()); + } + } - false - }); + // If ids is not empty, we need to update + !ids.is_empty() } /// Send responses through the `entry` response channel diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c4c1cf9a035..cd7c1f0f2ed 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -402,9 +402,6 @@ def filter( ) -> Optional["FlashCausalLMBatch"]: if len(updated_requests) == 0: raise ValueError("Batch must have at least one request") - # We assume that if len(requests) == len(self) then the requests are the same - if len(updated_requests) == len(self): - return self device = self.input_ids.device From 51fa6068758a9c56ea12a6259811ca9060687fa5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 5 Jun 2024 21:32:46 +0200 Subject: [PATCH 04/18] fix --- server/text_generation_server/models/vlm_causal_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index bc51e732b99..da7de2d3063 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -230,7 +230,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots_tensor[batch.slot_indices] + slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -269,7 +269,7 @@ def forward( cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache block_tables = batch.block_tables_tensor - slots = batch.slots_tensor[batch.slot_indices] + slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices From 3c596983ba489fcf03971cb096ee05747617f09d Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 6 Jun 2024 10:18:26 +0200 Subject: [PATCH 05/18] fix python tests --- server/tests/models/test_bloom.py | 17 ++++++++++++++--- server/tests/models/test_causal_lm.py | 24 +++++++++++++++++++++--- server/tests/models/test_seq2seq_lm.py | 17 ++++++++++++++--- 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 32ee6686b6b..0daa5f418eb 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -197,7 +197,9 @@ def test_causal_lm_generate_token_completion_multi( # Copy stopping_criterias before filtering stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy() - next_batch = next_batch.filter([next_batch.requests[0].id]) + next_batch = next_batch.filter( + [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + ) for _ in range( stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 @@ -306,7 +308,14 @@ def test_batch_concatenate( ) next_batch = next_batch.filter( - [next_batch.requests[0].id, next_batch.requests[1].id] + [ + generate_pb2.UpdatedRequest( + id=next_batch.requests[0].id, blocks=[], slots=[] + ), + generate_pb2.UpdatedRequest( + id=next_batch.requests[1].id, blocks=[], slots=[] 
+ ), + ] ) for _ in range( @@ -330,7 +339,9 @@ def test_batch_concatenate( == default_bloom_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[1].id]) + next_batch = next_batch.filter( + [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])] + ) for _ in range( default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6e6463bc948..547da81f76c 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -198,7 +198,9 @@ def test_causal_lm_generate_token_completion_multi( default_multi_requests_causal_lm_batch.stopping_criterias.copy() ) - next_batch = next_batch.filter([next_batch.requests[0].id]) + next_batch = next_batch.filter( + [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + ) for _ in range( stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 @@ -306,7 +308,14 @@ def test_batch_concatenate( ) next_batch = next_batch.filter( - [next_batch.requests[0].id, next_batch.requests[1].id] + [ + generate_pb2.UpdatedRequest( + id=next_batch.requests[0].id, blocks=[], slots=[] + ), + generate_pb2.UpdatedRequest( + id=next_batch.requests[1].id, blocks=[], slots=[] + ), + ] ) for _ in range( @@ -328,7 +337,16 @@ def test_batch_concatenate( == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[1].id]) + next_batch = next_batch.filter( + [ + generate_pb2.UpdatedRequest( + id=next_batch.requests[0].id, blocks=[], slots=[] + ), + generate_pb2.UpdatedRequest( + id=next_batch.requests[1].id, blocks=[], slots=[] + ), + ] + ) for _ in range( default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 943c3b0820d..17b5fa50c41 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -206,7 +206,9 @@ def test_seq2seq_lm_generate_token_completion_multi( ) assert generations[1].generated_text.generated_tokens == 5 - next_batch = next_batch.filter([next_batch.requests[0].id]) + next_batch = next_batch.filter( + [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -340,7 +342,14 @@ def test_batch_concatenate( assert generations[2].generated_text.generated_tokens == 5 next_batch = next_batch.filter( - [next_batch.requests[0].id, next_batch.requests[1].id] + [ + generate_pb2.UpdatedRequest( + id=next_batch.requests[0].id, blocks=[], slots=[] + ), + generate_pb2.UpdatedRequest( + id=next_batch.requests[1].id, blocks=[], slots=[] + ), + ] ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) @@ -351,7 +360,9 @@ def test_batch_concatenate( assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id assert generations[0].generated_text.generated_tokens == 7 - next_batch = next_batch.filter([next_batch.requests[1].id]) + next_batch = next_batch.filter( + [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])] + ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None From 298bf31e69dc934060f951e4287d54bd9abaa92b Mon Sep 17 00:00:00 2001 From: OlivierDehaene 
<23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:26:17 +0200 Subject: [PATCH 06/18] add terminated_generations --- proto/v3/generate.proto | 16 ++++++-- router/client/src/v3/client.rs | 6 ++- router/client/src/v3/mod.rs | 4 +- router/client/src/v3/sharded_client.rs | 15 ++++++-- router/src/infer/v3/block_allocator.rs | 8 ++-- router/src/infer/v3/queue.rs | 27 +++----------- router/src/infer/v3/scheduler.rs | 20 ++++++---- router/src/lib.rs | 3 ++ .../models/causal_lm.py | 2 +- .../models/flash_causal_lm.py | 37 ++++++++++++++++--- .../models/idefics_causal_lm.py | 2 +- server/text_generation_server/models/mamba.py | 2 +- .../models/seq2seq_lm.py | 2 +- server/text_generation_server/models/types.py | 13 +++++-- .../models/vlm_causal_lm.py | 2 +- server/text_generation_server/server.py | 8 +++- 16 files changed, 107 insertions(+), 60 deletions(-) diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 192cd111bb7..3c9b1d7171d 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -164,6 +164,7 @@ enum FinishReason { FINISH_REASON_LENGTH = 0; FINISH_REASON_EOS_TOKEN = 1; FINISH_REASON_STOP_SEQUENCE = 2; + FINISH_REASON_TERMINATED = 3; } message GeneratedText { @@ -198,11 +199,11 @@ message Generation { optional GeneratedText generated_text = 4; /// Top tokens repeated Tokens top_tokens = 5; - /// Current length of the request: prompt tokens + number of generated tokens until this point - uint32 current_length = 6; + /// Current length of the cache: prompt tokens + number of generated tokens until this point + uint32 cache_length = 6; } -message UpdatedRequest { +message KeptRequest { /// Request ID uint64 id = 1; /// Paged attention blocks @@ -211,16 +212,23 @@ message UpdatedRequest { repeated uint32 slots = 3; } +/// kept_requests + terminated_request_ids might not cover all requests from the +/// cached batch as some requests can be filtered out without requiring to generate text +/// for example if the client dropped its connection to the router message FilterBatchRequest { /// Batch ID uint64 batch_id = 1; /// Requests to keep - repeated UpdatedRequest updated_requests = 2; + repeated KeptRequest kept_requests = 2; + /// Requests to terminate and generate text for + repeated uint64 terminated_request_ids = 3; } message FilterBatchResponse { /// Filtered Batch (cached) CachedBatch batch = 1; + /// Terminated generations + repeated GeneratedText terminated_generations = 2; } diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 8cefd3137b2..90f43270e0f 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -90,11 +90,13 @@ impl Client { pub async fn filter_batch( &mut self, batch_id: u64, - updated_requests: Vec, + kept_requests: Vec, + terminated_request_ids: Vec, ) -> Result> { let request = tonic::Request::new(FilterBatchRequest { batch_id, - updated_requests, + kept_requests, + terminated_request_ids, }) .inject_context(); let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs index df2bb380734..ea7486eec0a 100644 --- a/router/client/src/v3/mod.rs +++ b/router/client/src/v3/mod.rs @@ -7,7 +7,7 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v3::{ input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, - HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, - StoppingCriteriaParameters, 
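// A minimal sketch of the intended call shape for the reworked FilterBatch API
// (illustrative only, not part of the patch; `blocks`/`slots` are assumed to come
// from the entry's current block allocation):
//     let kept = vec![KeptRequest { id: 42, blocks, slots }];
//     let filtered = client.filter_batch(batch_id, kept, vec![7]).await?;
// Request 42 keeps generating with its allocation, while request 7 is terminated
// and its final GeneratedText comes back in the FilterBatchResponse.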
Tokens, UpdatedRequest, + HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index a066176ce5e..e1b35a212ab 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -9,8 +9,8 @@ use tonic::transport::Uri; use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; use v3::{ - Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, - NextTokenChooserParameters, Request, StoppingCriteriaParameters, UpdatedRequest, + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; #[derive(Debug, Clone)] @@ -84,12 +84,19 @@ impl ShardedClient { pub async fn filter_batch( &mut self, batch_id: u64, - updated_requests: Vec, + kept_requests: Vec, + terminated_request_ids: Vec, ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.filter_batch(batch_id, updated_requests.clone()))) + .map(|client| { + Box::pin(client.filter_batch( + batch_id, + kept_requests.clone(), + terminated_request_ids.clone(), + )) + }) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 3e7cde893ed..db6180345d1 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,4 +1,4 @@ -use std::cmp::{max, min}; +use std::cmp::min; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; @@ -16,8 +16,9 @@ impl BlockAllocation { self.slots.len() } - pub(crate) async fn extend(&mut self, current_length: u32) -> Result<(), AllocationError> { - let remaining_tokens = max(self.prompt_tokens + self.decode_tokens - current_length, 1); + pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> { + let remaining_tokens = + (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length); self.block_allocator .clone() .extend(self, remaining_tokens) @@ -131,6 +132,7 @@ async fn block_allocator_task( let decode_tokens = min(decode_tokens, block_size); let tokens = prompt_tokens + decode_tokens; + // FIXME: window size is not working // Apply window size let (required_blocks, repeats) = { let (tokens, repeats) = match window_size { diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 9a7b1084bbe..cbe1fbd0d83 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -5,7 +5,7 @@ use crate::validation::{ ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::max; use std::collections::VecDeque; use text_generation_client::v3::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, @@ -33,8 +33,8 @@ pub(crate) struct Entry { pub batch_time: Option, /// Block Allocation pub block_allocation: Option, - /// Current length (in tokens) of the request (prompt tokens + generated_tokens) - pub current_length: u32, + /// Cache length (in tokens) of the request (prompt tokens + generated_tokens) + pub cache_length: u32, } /// Request Queue @@ -164,9 +164,6 @@ struct State { /// Paged Attention block size block_size: u32, - /// Sliding 
window - window_size: Option, - /// Speculation amount speculate: u32, @@ -190,7 +187,6 @@ impl State { next_id: 0, next_batch_id: 0, block_size, - window_size, speculate, block_allocator, } @@ -276,18 +272,7 @@ impl State { } Some(block_allocator) => { prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - decode_tokens += max_new_tokens; - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { + if prefill_tokens > prefill_token_budget { // Entry is over budget // Add it back to the front tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); @@ -296,7 +281,7 @@ impl State { } let decode_tokens = - entry.request.stopping_parameters.max_new_tokens + self.speculate - 1; + entry.request.stopping_parameters.max_new_tokens + self.speculate; match block_allocator .allocate(entry.request.input_length, decode_tokens) .await @@ -500,7 +485,7 @@ mod tests { queue_time: Instant::now(), batch_time: None, block_allocation: None, - current_length: 0, + cache_length: 0, }; (entry, receiver_tx) } diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index b76c5c50eab..fa1a9ac7c1c 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -10,7 +10,7 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient, UpdatedRequest}; +use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient}; use text_generation_client::ClientError; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; @@ -88,7 +88,7 @@ impl Scheduler for SchedulerV3 { queue_time: Instant::now(), batch_time: None, block_allocation: None, - current_length: input_length, + cache_length: 0, }); // Notify the background task that we have a new entry in the queue that needs @@ -350,7 +350,7 @@ async fn filter_batch( .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone())) .unwrap_or((Vec::new(), Vec::new())); - UpdatedRequest { + KeptRequest { id: *request_id, blocks, slots, @@ -359,7 +359,10 @@ async fn filter_batch( .collect(); // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, updated_requests).await.unwrap() + client + .filter_batch(id, updated_requests, Vec::new()) + .await + .unwrap() } } @@ -374,7 +377,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap) -> bool { .block_allocation .as_ref() .map(|block_allocation| { - if entry.current_length > block_allocation.len() as u32 { + if entry.cache_length > block_allocation.len() as u32 { // We need to re-allocate Some(*id) } else { @@ -424,8 +427,8 @@ async fn filter_update_allocations(entries: &mut IntMap) -> bool { entry .block_allocation .as_mut() - .unwrap() - .extend(entry.current_length) + .expect("We checked that the block allocation exists above") + .extend(entry.cache_length) .await }; @@ -563,6 +566,7 @@ impl From for GeneratedText { let v3_finish_reason = text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); let finish_reason = match v3_finish_reason { + 
text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources, text_generation_client::v3::FinishReason::Length => FinishReason::Length, text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, diff --git a/router/src/lib.rs b/router/src/lib.rs index b6902c497c1..2f115aba94f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason { EndOfSequenceToken, #[schema(rename = "stop_sequence")] StopSequence, + #[schema(rename = "out_of_resources")] + OutOfResources, } impl std::fmt::Display for FinishReason { @@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason { FinishReason::Length => write!(f, "length"), FinishReason::EndOfSequenceToken => write!(f, "eos_token"), FinishReason::StopSequence => write!(f, "stop_sequence"), + FinishReason::OutOfResources => write!(f, "out_of_resources"), } } } diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 50a25a50bf0..40a4f100221 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -159,7 +159,7 @@ def from_pb( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.UpdatedRequest] + self, updated_requests: List[generate_pb2.KeptRequest] ) -> Optional["CausalLMBatch"]: request_ids = [r.id for r in updated_requests] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index cd7c1f0f2ed..0bd9357f49a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -398,11 +398,37 @@ def from_pb( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.UpdatedRequest] - ) -> Optional["FlashCausalLMBatch"]: - if len(updated_requests) == 0: + self, + model: "FlashCausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]: + if len(kept_requests) == 0: raise ValueError("Batch must have at least one request") + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + do_sample = self.next_token_chooser.do_sample[idx] + seed = self.next_token_chooser.seeds[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + generate_pb2.FINISH_REASON_TERMINATED, + seed if do_sample else None, + ) + terminated_generations.append(generated_text) + device = self.input_ids.device # New values after filtering @@ -429,7 +455,7 @@ def filter( num_blocks = 0 max_blocks = 0 - for i, request in enumerate(updated_requests): + for i, request in enumerate(kept_requests): request_id = request.id idx = self.requests_idx_mapping[request_id] @@ -491,7 +517,7 @@ def filter( # Move to GPU block_tables_tensor = block_tables_tensor.to(device) - return type(self)( + filtered_batch = 
type(self)( batch_id=self.batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -520,6 +546,7 @@ def filter( max_blocks=max_blocks, speculative_ids=speculative_ids, ) + return filtered_batch, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index fd70ae5d149..495a47e5e22 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -215,7 +215,7 @@ def from_pb_processor( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.UpdatedRequest] + self, updated_requests: List[generate_pb2.KeptRequest] ) -> Optional["IdeficsCausalLMBatch"]: request_ids = [r.id for r in updated_requests] diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index c8066aec4ba..0340ca5519e 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -196,7 +196,7 @@ def from_pb( ) def filter( - self, updated_requests: List[generate_pb2.UpdatedRequest] + self, updated_requests: List[generate_pb2.KeptRequest] ) -> Optional["MambaBatch"]: request_ids = [r.id for r in updated_requests] diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 1e4f7c2e54d..77407118d44 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -167,7 +167,7 @@ def from_pb( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.UpdatedRequest] + self, updated_requests: List[generate_pb2.KeptRequest] ) -> Optional["Seq2SeqLMBatch"]: request_ids = [r.id for r in updated_requests] diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 50c14862762..c19f804eada 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from transformers import PreTrainedTokenizerBase @@ -28,7 +28,12 @@ def from_pb( raise NotImplementedError @abstractmethod - def filter(self, updated_requests: List[generate_pb2.UpdatedRequest]) -> "Batch": + def filter( + self, + model, + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]: raise NotImplementedError @classmethod @@ -84,7 +89,7 @@ class Generation: generated_text: Optional[GeneratedText] # Optional for now, since it's not yet supported for every model. 
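    # Illustration of the cache_length bookkeeping introduced below (not part of
    # the original patch): the field mirrors the proto definition, i.e. tokens
    # already materialised in the KV cache for this request; e.g. a 12-token
    # prompt after 3 decode steps yields cache_length == 15. The router compares
    # this value against the length of the block allocation to decide when to
    # extend it.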
top_tokens: Optional[List[Tokens]] - current_length: int + cache_length: int def to_pb(self) -> generate_pb2.Generation: return generate_pb2.Generation( @@ -101,5 +106,5 @@ def to_pb(self) -> generate_pb2.Generation: if self.top_tokens is not None else None ), - current_length=self.current_length, + cache_length=self.cache_length, ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index da7de2d3063..e3d0bee8690 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -123,7 +123,7 @@ def concatenate(cls, batches): @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.UpdatedRequest] + self, updated_requests: List[generate_pb2.KeptRequest] ) -> Optional["VlmCausalLMBatch"]: batch = super().filter(updated_requests) batch.pixel_values = None diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index a66c19a0ce5..86df66e7eac 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -83,10 +83,14 @@ async def FilterBatch(self, request, context): batch = self.cache.pop(request.batch_id) if batch is None: raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - filtered_batch = batch.filter(request.updated_requests) + filtered_batch, terminated_generations = batch.filter( + self.model, request.kept_requests, request.terminated_request_ids + ) self.cache.set(filtered_batch) - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + return generate_pb2.FilterBatchResponse( + batch=filtered_batch.to_pb(), terminated_generations=terminated_generations + ) async def Warmup(self, request, context): if self.quantize in {"exl2", "gptq"}: From 713d70b443af83da5fd3a96e3aa9274de31d652b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 7 Jun 2024 13:39:42 +0200 Subject: [PATCH 07/18] re-working logic, wip --- router/src/infer/v3/block_allocator.rs | 235 ++++++++++++------------- router/src/infer/v3/queue.rs | 1 - router/src/infer/v3/scheduler.rs | 3 +- 3 files changed, 111 insertions(+), 128 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index db6180345d1..0d6c7cfa625 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,6 +1,6 @@ use std::cmp::min; +use std::sync::{Arc, Mutex}; use thiserror::Error; -use tokio::sync::{mpsc, oneshot}; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { @@ -16,13 +16,23 @@ impl BlockAllocation { self.slots.len() } - pub(crate) async fn extend(&mut self, cache_length: u32) -> Result<(), AllocationError> { - let remaining_tokens = - (self.prompt_tokens + self.decode_tokens).saturating_sub(cache_length); - self.block_allocator - .clone() - .extend(self, remaining_tokens) - .await + pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { + let (block, slots) = self.block_allocator.allocate_block()?; + + match self.block_allocator.window_size { + None => { + self.blocks.push(block); + self.slots.extend(slots); + } + Some(window_size) => { + if self.len() as u32 > window_size { + let total_tokens = self.prompt_tokens + self.decode_tokens; + + let repeats = (total_tokens + window_size - 1) / window_size; + } + } + } + Ok(()) } } @@ -34,8 +44,9 @@ impl Drop for BlockAllocation { #[derive(Debug, Clone)] 
pub(crate) struct BlockAllocator { - /// Channel to communicate with the background task - block_allocator: mpsc::UnboundedSender, + free_blocks: Arc>>, + block_size: u32, + window_size: Option, } impl BlockAllocator { @@ -44,39 +55,105 @@ impl BlockAllocator { block_size: u32, window_size: Option, ) -> Self { - // Create channel - let (sender, receiver) = mpsc::unbounded_channel(); + let blocks = max_batch_total_tokens / block_size; + // Block 0 is reserved for health checks + let free_blocks: Vec = (1..blocks).collect(); - // Launch background queue task - tokio::spawn(block_allocator_task( - max_batch_total_tokens / block_size, + Self { + free_blocks: Arc::new(Mutex::new(free_blocks)), block_size, window_size, - receiver, - )); + } + } - Self { - block_allocator: sender, + fn allocate_block(&self) -> Result<(u32, Vec), AllocationError> { + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + + if free_blocks.is_empty() { + return Err(AllocationError::NotEnoughPages); } + + let block_id = free_blocks.pop().unwrap(); + let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect(); + Ok((block_id, slots)) } - pub(crate) async fn allocate( + /// For prompt tokens, we allocate enough blocks to cover all tokens + /// For decode tokens, we allocate block by block + /// + /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots + fn allocate( &self, prompt_tokens: u32, decode_tokens: u32, - ) -> Result { - let (response_sender, response_receiver) = oneshot::channel(); - self.block_allocator - .send(BlockAllocatorCommand::Allocate { - prompt_tokens, - decode_tokens, - response_sender, - }) - .unwrap(); + ) -> Result<(Vec, Vec), AllocationError> { + // let decode_tokens = min(decode_tokens, self.block_size); + // let tokens = prompt_tokens + decode_tokens; + + let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; + // prompt blocks + a single block for decode + let required_blocks = required_prompt_blocks + 1; + + let (required_blocks, repeats) = match self.window_size { + // Nothing to do + None => (required_blocks, 1), + Some(window_size) => { + // Number of blocks needed for this window size + let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size; + // Number of times we will need to repeat blocks to cover the required allocation + let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks; + let required_blocks = min(required_blocks, window_size_required_blocks); + + (required_blocks, repeats) + } + }; + + + /// if prompt + decode < window size => do nothing + /// if prompt + decode > window size => do normal until we reach window size then + + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + + if required_blocks > free_blocks.len() as u32 { + Err(AllocationError::NotEnoughPages) + } else { + let n_free_blocks = free_blocks.len(); + let blocks = + free_blocks.split_off(n_free_blocks - required_blocks as usize); + let mut 
slots = Vec::with_capacity( + (required_blocks * self.block_size * repeats as u32) as usize, + ); + + for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + } + } + Ok((blocks, slots)) + } + } - response_receiver - .await - .unwrap() + pub(crate) fn block_allocation( + &self, + prompt_tokens: u32, + decode_tokens: u32, + ) -> Result { + self.allocate_inner(prompt_tokens, decode_tokens) .map(|(blocks, slots)| BlockAllocation { blocks, slots, @@ -86,103 +163,11 @@ impl BlockAllocator { }) } - pub(crate) async fn extend( - &self, - block_allocation: &mut BlockAllocation, - tokens: u32, - ) -> Result<(), AllocationError> { - let (response_sender, response_receiver) = oneshot::channel(); - self.block_allocator - .send(BlockAllocatorCommand::Allocate { - prompt_tokens: 0, - decode_tokens: tokens, - response_sender, - }) - .unwrap(); - - let (blocks, slots) = response_receiver.await.unwrap()?; - block_allocation.blocks.extend(blocks); - block_allocation.slots.extend(slots); - Ok(()) - } - pub(crate) fn free(&self, blocks: Vec) { - self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) - .unwrap(); - } -} - -async fn block_allocator_task( - blocks: u32, - block_size: u32, - window_size: Option, - mut receiver: mpsc::UnboundedReceiver, -) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); - while let Some(cmd) = receiver.recv().await { - match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), - BlockAllocatorCommand::Allocate { - prompt_tokens, - decode_tokens, - response_sender, - } => { - let decode_tokens = min(decode_tokens, block_size); - let tokens = prompt_tokens + decode_tokens; - - // FIXME: window size is not working - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let allocation = if required_blocks > free_blocks.len() as u32 { - Err(AllocationError::NotEnoughPages) - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - } - } - Ok((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); - } - } + self.free_blocks.lock().expect("Lock could not be acquired. 
This is a bug.").extend(blocks) } } -#[derive(Debug)] -enum BlockAllocatorCommand { - Free { - blocks: Vec, - }, - Allocate { - prompt_tokens: u32, - decode_tokens: u32, - #[allow(clippy::type_complexity)] - response_sender: oneshot::Sender, Vec), AllocationError>>, - }, -} - #[derive(Error, Debug)] pub enum AllocationError { #[error("Not enough pages")] diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index cbe1fbd0d83..14e67fffed5 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -284,7 +284,6 @@ impl State { entry.request.stopping_parameters.max_new_tokens + self.speculate; match block_allocator .allocate(entry.request.input_length, decode_tokens) - .await { Err(_) => { // Entry is over budget diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index fa1a9ac7c1c..0913f495b33 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -428,8 +428,7 @@ async fn filter_update_allocations(entries: &mut IntMap) -> bool { .block_allocation .as_mut() .expect("We checked that the block allocation exists above") - .extend(entry.cache_length) - .await + .extend() }; if extension.is_err() { From 6983ec9537ecb739dc9c57fa8e17a0a82bfaec1e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 10 Jun 2024 11:44:50 +0200 Subject: [PATCH 08/18] small refactor --- router/src/infer/v3/block_allocator.rs | 144 +++++++++++++------------ router/src/infer/v3/queue.rs | 8 +- router/src/infer/v3/scheduler.rs | 2 +- 3 files changed, 78 insertions(+), 76 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 0d6c7cfa625..a084a5056e9 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,44 +1,55 @@ -use std::cmp::min; use std::sync::{Arc, Mutex}; use thiserror::Error; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { - pub blocks: Vec, - pub slots: Vec, - prompt_tokens: u32, - decode_tokens: u32, + allocated_blocks: Vec, + allocated_slots: Vec, + required_blocks: usize, + required_slots: usize, block_allocator: BlockAllocator, } impl BlockAllocation { pub(crate) fn len(&self) -> usize { - self.slots.len() + self.allocated_slots.len() } - pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { - let (block, slots) = self.block_allocator.allocate_block()?; + pub(crate) fn blocks(&self) -> &[u32] { + &self.allocated_blocks + } - match self.block_allocator.window_size { - None => { - self.blocks.push(block); - self.slots.extend(slots); - } - Some(window_size) => { - if self.len() as u32 > window_size { - let total_tokens = self.prompt_tokens + self.decode_tokens; + pub(crate) fn slots(&self) -> &[u32] { + &self.allocated_slots + } - let repeats = (total_tokens + window_size - 1) / window_size; - } + pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { + let (block, slots) = self.block_allocator.allocate_block()?; + // Add block and slots to current allocation + self.allocated_blocks.push(block); + self.allocated_slots.extend(slots); + + if let Some(window_size) = self.block_allocator.window_size { + // if we have more slots than the window size, + // we will never need to re-allocate and we can just repeat the blocks/slots + let window_size = window_size as usize; + if self.len() > window_size { + let repeats = (self.required_slots + window_size - 1) / window_size; + self.allocated_blocks = 
self.allocated_blocks.repeat(repeats); + self.allocated_blocks.truncate(self.required_blocks); + self.allocated_slots = self.allocated_slots.repeat(repeats); + self.allocated_slots.truncate(self.required_slots); } } + Ok(()) } } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + let allocated_blocks = std::mem::take(&mut self.allocated_blocks); + self.block_allocator.free(allocated_blocks) } } @@ -82,85 +93,76 @@ impl BlockAllocator { /// For decode tokens, we allocate block by block /// /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots - fn allocate( + pub(crate) fn block_allocation( &self, prompt_tokens: u32, decode_tokens: u32, - ) -> Result<(Vec, Vec), AllocationError> { - // let decode_tokens = min(decode_tokens, self.block_size); - // let tokens = prompt_tokens + decode_tokens; - + ) -> Result { let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; // prompt blocks + a single block for decode let required_blocks = required_prompt_blocks + 1; - let (required_blocks, repeats) = match self.window_size { + let (clipped_required_blocks, repeats) = match self.window_size { // Nothing to do None => (required_blocks, 1), Some(window_size) => { - // Number of blocks needed for this window size - let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size; - // Number of times we will need to repeat blocks to cover the required allocation - let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks; - let required_blocks = min(required_blocks, window_size_required_blocks); - - (required_blocks, repeats) + // Number of blocks for this window size + let window_size_blocks = (window_size + self.block_size - 1) / self.block_size; + + if required_blocks > window_size_blocks { + // Number of times we will need to repeat blocks to cover the required allocation + let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks; + (window_size_blocks, repeats) + } else { + (required_blocks, 1) + } } }; - - /// if prompt + decode < window size => do nothing - /// if prompt + decode > window size => do normal until we reach window size then - - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match self.window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + self.block_size - 1) / self.block_size; - (required_blocks, repeats) - }; + let repeats = repeats as usize; + let required_blocks = required_blocks as usize; + let clipped_required_blocks = clipped_required_blocks as usize; let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); - if required_blocks > free_blocks.len() as u32 { + if clipped_required_blocks > free_blocks.len() { Err(AllocationError::NotEnoughPages) } else { let n_free_blocks = free_blocks.len(); - let blocks = - free_blocks.split_off(n_free_blocks - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * self.block_size * repeats as u32) as usize, + let allocated_blocks = + free_blocks.split_off(n_free_blocks - clipped_required_blocks); + + let allocated_blocks = if repeats != 1 { + let mut allocated_blocks = allocated_blocks.repeat(repeats); + 
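                // Worked example (illustrative numbers, not from the patch):
                // block_size = 16 and window_size = 64 give window_size_blocks = 4.
                // A 70-token prompt needs ceil(70 / 16) = 5 prompt blocks plus one
                // decode block, so required_blocks = 6 > 4 and repeats =
                // ceil(6 / 4) = 2: the 4 allocated physical blocks are repeated to
                // 8 logical entries, then truncated back to required_blocks (6).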
allocated_blocks.truncate(required_blocks); + allocated_blocks + } else { + allocated_blocks + }; + + let mut allocated_slots = Vec::with_capacity( + allocated_blocks.len() * self.block_size as usize * repeats, ); - for block_id in blocks.repeat(repeats).iter() { + let required_slots = (prompt_tokens + decode_tokens) as usize; + + 'slots: for block_id in allocated_blocks.iter() { for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - slots.push(s); + allocated_slots.push(s); + if allocated_slots.len() > required_slots { + break 'slots; + } } } - Ok((blocks, slots)) - } - } - pub(crate) fn block_allocation( - &self, - prompt_tokens: u32, - decode_tokens: u32, - ) -> Result { - self.allocate_inner(prompt_tokens, decode_tokens) - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - prompt_tokens, - decode_tokens, + Ok(BlockAllocation { + allocated_blocks, + allocated_slots, + required_blocks, + required_slots, block_allocator: self.clone(), }) + } } pub(crate) fn free(&self, blocks: Vec) { diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 14e67fffed5..db09f9b4596 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -283,7 +283,7 @@ impl State { let decode_tokens = entry.request.stopping_parameters.max_new_tokens + self.speculate; match block_allocator - .allocate(entry.request.input_length, decode_tokens) + .block_allocation(entry.request.input_length, decode_tokens) { Err(_) => { // Entry is over budget @@ -294,7 +294,7 @@ impl State { } Ok(block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + max_blocks = max(max_blocks, block_allocation.blocks().len() as u32); Some(block_allocation) } } @@ -313,8 +313,8 @@ impl State { let (blocks, slots) = match &block_allocation { None => (Vec::new(), Vec::new()), Some(block_allocation) => ( - block_allocation.blocks.clone(), - block_allocation.slots.clone(), + block_allocation.blocks().to_vec(), + block_allocation.slots().to_vec(), ), }; diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 0913f495b33..5fb9e11d6c7 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -347,7 +347,7 @@ async fn filter_batch( let (blocks, slots) = entry .block_allocation .as_ref() - .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone())) + .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec())) .unwrap_or((Vec::new(), Vec::new())); KeptRequest { From 73c39032142f7bc7243402e7ec9fa9f25e4bbcb3 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:38:07 +0200 Subject: [PATCH 09/18] FlashCausalLM implem --- proto/v3/generate.proto | 9 +- router/client/src/v3/client.rs | 4 +- router/client/src/v3/mod.rs | 2 +- router/client/src/v3/sharded_client.rs | 4 +- router/src/infer/v3/scheduler.rs | 353 ++++++++++++------ .../models/flash_causal_lm.py | 23 +- server/text_generation_server/models/types.py | 2 +- server/text_generation_server/server.py | 6 +- 8 files changed, 274 insertions(+), 129 deletions(-) diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 3c9b1d7171d..8138e4fb4d5 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -224,11 +224,18 @@ message FilterBatchRequest { repeated uint64 terminated_request_ids = 3; } +message TerminatedGeneration { + // Request ID + uint64 id = 1; + // Generated text + GeneratedText 
generated_text = 2; +} + message FilterBatchResponse { /// Filtered Batch (cached) CachedBatch batch = 1; /// Terminated generations - repeated GeneratedText terminated_generations = 2; + repeated TerminatedGeneration terminated_generations = 2; } diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 90f43270e0f..1f8070cade6 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -92,7 +92,7 @@ impl Client { batch_id: u64, kept_requests: Vec, terminated_request_ids: Vec, - ) -> Result> { + ) -> Result<(Option, Vec)> { let request = tonic::Request::new(FilterBatchRequest { batch_id, kept_requests, @@ -100,7 +100,7 @@ impl Client { }) .inject_context(); let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); - Ok(filtered_batch.batch) + Ok((filtered_batch.batch, filtered_batch.terminated_generations)) } /// Warmup on a max size batch diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs index ea7486eec0a..9df17c50947 100644 --- a/router/client/src/v3/mod.rs +++ b/router/client/src/v3/mod.rs @@ -8,6 +8,6 @@ pub use client::Client; pub use pb::generate::v3::{ input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest, - NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens, }; pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index e1b35a212ab..3f11e101bf8 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -2,7 +2,7 @@ use crate::{v3, Health, ShardInfo}; use crate::{ClientError, Result}; -use crate::v3::{Chunk, InfoResponse, Input}; +use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration}; use async_trait::async_trait; use futures::future::join_all; use tonic::transport::Uri; @@ -86,7 +86,7 @@ impl ShardedClient { batch_id: u64, kept_requests: Vec, terminated_request_ids: Vec, - ) -> Result> { + ) -> Result<(Option, Vec)> { let futures: Vec<_> = self .clients .iter_mut() diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 5fb9e11d6c7..ee93c20ab35 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -5,12 +5,14 @@ use crate::infer::{ }; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; +use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use text_generation_client::v3::{Batch, CachedBatch, Generation, KeptRequest, ShardedClient}; +use text_generation_client::v3::{ + Batch, CachedBatch, Generation, KeptRequest, ShardedClient, TerminatedGeneration, +}; use text_generation_client::ClientError; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; @@ -243,11 +245,38 @@ async fn prefill( generation_health.store(true, Ordering::SeqCst); let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + // Filter and send finished generations + let filtered_stream_responses = filter_send_ended_generations(generations, entries); + + // Iterate on intermediate generations + for (id, stream_responses) in filtered_stream_responses { + // Get entry + 
let entry = entries + .get_mut(&id) + .expect("ID not found in entries. This is a bug."); + + // Send intermediate responses + if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }) { + // Sending failed, remove entry + entries + .remove(&id) + .expect("ID not found in entries. This is a bug."); + } + } // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries, false).await; + let next_batch = match next_batch { + Some(batch) if batch.size as usize != entries.len() => { + let (filtered_batch, _) = + filter_batch(client, batch, entries, &IntMap::default()).await; + filtered_batch + } + batch => batch, + }; metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); @@ -285,13 +314,32 @@ async fn decode( generation_health.store(true, Ordering::SeqCst); let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - let updated = filter_update_allocations(entries).await; - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries, updated).await; + // Filter and send finished generations + let mut filtered_stream_responses = filter_send_ended_generations(generations, entries); + // Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be + // re-allocated, + // Allocated new blocks for entries that go over their allocation + // Filter entries that couldn't be re-allocated and add them to `terminated_entries` + let (force_update, terminated_entries) = + filter_send_update_allocations(entries, &mut filtered_stream_responses); + + let next_batch = match next_batch { + // Run Only on re-allocation or if entries were filtered + Some(batch) if batch.size as usize != entries.len() || force_update => { + // Filter next batch: remove requests that were stopped and update blocks/slots + let (filtered_batch, terminated_generations) = + filter_batch(client, batch, entries, &terminated_entries).await; + send_terminated_generations( + terminated_generations, + terminated_entries, + filtered_stream_responses, + ); + + filtered_batch + } + batch => batch, + }; if let Some(concat_duration) = timings.concat { metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); @@ -320,27 +368,20 @@ async fn decode( #[instrument(skip_all)] async fn filter_batch( client: &mut ShardedClient, - next_batch: Option, + batch: CachedBatch, entries: &IntMap, - force_update: bool, -) -> Option { - let batch = next_batch?; - - // No need to filter - if batch.size as usize == entries.len() && !force_update { - return Some(batch); - } - + terminated_entries: &IntMap, +) -> (Option, Vec) { let id = batch.id; - if entries.is_empty() { + if entries.is_empty() && terminated_entries.is_empty() { // All requests have been filtered out // Next batch is now empty // Clear it from the Python shards cache // We unwrap here as we need to panic since we cannot recover if this method fails client.clear_cache(Some(id)).await.unwrap(); - None + Default::default() } else { - // Filter Python shard cache + // Collect new blocks/slots let updated_requests = entries .iter() 
.map(|(request_id, entry)| { @@ -348,7 +389,7 @@ async fn filter_batch( .block_allocation .as_ref() .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec())) - .unwrap_or((Vec::new(), Vec::new())); + .unwrap_or_default(); KeptRequest { id: *request_id, @@ -358,111 +399,207 @@ async fn filter_batch( }) .collect(); + // Filter Python shard cache // We unwrap here as we need to panic since we cannot recover if this method fails client - .filter_batch(id, updated_requests, Vec::new()) + .filter_batch( + id, + updated_requests, + terminated_entries.keys().map(|v| *v).collect(), + ) .await .unwrap() } } -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries +/// #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { +fn send_terminated_generations( + terminated_generations: Vec, + terminated_entries: IntMap, + mut stream_responses: IntMap>, +) { + // Receive final message for terminated generations + 'terminated_generations: for terminated_generation in terminated_generations { + let id = terminated_generation.id; + // Get entry for this generation + let entry = terminated_entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + // Get previous `InferStreamResponse` for this generation + let stream_responses = stream_responses + .remove(&id) + .expect("ID not found in stream_responses. This is a bug."); + + // Peekable iterator to know when we are at the last `InferStreamResponse` + let mut iterator = stream_responses.into_iter().peekable(); + + while let Some(stream_response) = iterator.next() { + let response = if iterator.peek().is_none() { + // Last `InferStreamResponse::Intermediate` + let (token, top_tokens) = match stream_response { + InferStreamResponse::Intermediate { token, top_tokens } => (token, top_tokens), + _ => unreachable!(), + }; + // Modify it to be a `InferStreamResponse::End` with the new OutOfResources finish + // reason + InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from( + terminated_generation + .generated_text + .clone() + .expect("Generated Text is None. This is a bug."), + ), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + } + } else { + stream_response + }; + + // Send responses + if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }) { + continue 'terminated_generations; + } + } + } +} + +/// Send `InferStreamResponse::End` to `Infer` for finished entries and remove them from `entries` +/// Returns filtered `InferStreamResponse::Intermediate` generations +#[instrument(skip_all)] +fn filter_send_ended_generations( + generations: Vec, + entries: &mut IntMap, +) -> IntMap> { + generations.into_iter().filter_map(|generation| { let id = generation.request_id; // Get entry // We can `expect` here as the request id should always be in the entries let entry = entries .get_mut(&id) .expect("ID not found in entries. This is a bug."); - entry.cache_length = generation.cache_length; // Create and enter a span to link this function back to the entry let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. 
This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); + + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }).unwrap_or(true); - if stopped { + // Remove from entries and filter entries.remove(&id).expect("ID not found in entries. This is a bug."); + return None; } - }); + + // Update cache length + entry.cache_length = generation.cache_length; + + let (finished, stream_responses) = map_generation(generation, entry); + // If the generation has ended for this request, we send the responses to the channel and + // remove the entry to drop it and free its blocks + if finished { + let _ = send_stream_responses(stream_responses, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }); + // Remove from entries and filter + entries.remove(&id).expect("ID not found in entries. This is a bug."); + return None; + } + + Some((id, stream_responses)) + }).collect() +} + +/// Send `InferStreamResponse` to `Infer` through an `Entry` response channel +#[instrument(skip_all)] +fn send_stream_responses( + stream_responses: Vec, + entry: &Entry, +) -> Result<(), Box>>> { + for response in stream_responses { + entry.response_tx.send(Ok(response))?; + } + Ok(()) } /// Check if block allocations need to be extended -/// If we don't have enough blocks, request will be filtered with an OutOfPages error +/// If we don't have enough blocks, request will be filtered with be added to an IntMap of +/// terminated entries. +/// If at least one entry allocation was extended, we return true to force an update #[instrument(skip_all)] -async fn filter_update_allocations(entries: &mut IntMap) -> bool { - let ids: Vec = entries - .iter() - .filter_map(|(id, entry)| { - entry - .block_allocation - .as_ref() - .map(|block_allocation| { - if entry.cache_length > block_allocation.len() as u32 { - // We need to re-allocate - Some(*id) - } else { - None - } - }) - .unwrap_or(None) - }) - .collect(); +fn filter_send_update_allocations( + entries: &mut IntMap, + stream_responses: &mut IntMap>, +) -> (bool, IntMap) { + let mut updated = false; - for id in ids.iter() { - // Get entry - // We can `expect` here as the request id should always be in the entries - let extension = { - let entry = entries - .get_mut(id) - .expect("ID not found in entries. This is a bug."); - entry - .block_allocation - .as_mut() - .expect("We checked that the block allocation exists above") - .extend() - }; + let ids: Vec = entries.keys().map(|v| *v).collect(); + let mut terminated_entries = + IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default()); - if extension.is_err() { - let entry = entries - .remove(id) - .expect("ID not found in entries. This is a bug."); + for id in &ids { + let entry = entries + .get_mut(id) + .expect("ID not found in entries. This is a bug."); - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. 
This is a bug."), "send_error").entered(); - let err = InferError::OutOfPages; - metrics::increment_counter!("tgi_request_failure", "err" => "out_of_pages"); - tracing::error!("{err}"); + if let Some(block_allocation) = entry.block_allocation.as_mut() { + // Check if allocation can handle the current cache_length + if entry.cache_length > block_allocation.len() as u32 { + updated = true; + + // Extend allocation by asking for a new block + if let Err(err) = block_allocation.extend() { + // Failed to extend allocation + tracing::error!("Failed to extend allocation: {err}"); + metrics::increment_counter!("tgi_request_failure", "err" => "out_of_resources"); + + // Remove entry + let mut entry = entries + .remove(id) + .expect("ID not found in entries. This is a bug."); + // Clear block allocation + entry.block_allocation = None; + // Add it to terminated entries + terminated_entries.insert(*id, entry); + // Skip the rest of the logic to not send the intermediate messages + // This entry will be terminated and we will need to edit the last intermediate + // response to add the complete generated text + continue; + } + } + } + let stream_response = stream_responses + .remove(id) + .expect("ID not found in stream_responses. This is a bug."); - // unwrap_or is valid here as we don't care if the receiver is gone. - entry.response_tx.send(Err(err)).unwrap_or(()); + // Send intermediate responses + if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }) { + // Sending failed, remove entry + entries + .remove(id) + .expect("ID not found in entries. This is a bug."); } } - // If ids is not empty, we need to update - !ids.is_empty() + (updated, terminated_entries) } -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, - entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - return Ok(true); - } - - let mut stopped = false; +/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>` +fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec) { + let mut finished = false; + let mut stream_responses = Vec::with_capacity(16); if let Some(prefill_tokens) = generation.prefill_tokens { // Create Token objects @@ -475,10 +612,8 @@ fn send_responses( .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) .collect(); - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + // Push to stream_responses + stream_responses.push(InferStreamResponse::Prefill(prefill_tokens)); } // Create last Token @@ -520,26 +655,24 @@ fn send_responses( match (&generation.generated_text, iterator.peek()) { (Some(generated_text), None) => { // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { + finished = true; + // Push to stream_responses + stream_responses.push(InferStreamResponse::End { token, top_tokens, generated_text: GeneratedText::from(generated_text.clone()), queued: entry.queue_time, start: entry.batch_time.unwrap(), - }))?; + }); } _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + // Push to stream_responses + stream_responses.push(InferStreamResponse::Intermediate { 
token, top_tokens }); } } } - Ok(stopped) + (finished, stream_responses) } /// Send errors to Infer for all `entries` diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 0bd9357f49a..e8fd8b1667c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -402,10 +402,7 @@ def filter( model: "FlashCausalLM", kept_requests: List[generate_pb2.KeptRequest], terminated_request_ids: List[int], - ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.GeneratedText]]: - if len(kept_requests) == 0: - raise ValueError("Batch must have at least one request") - + ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: terminated_generations = [] for request_id in terminated_request_ids: idx = self.requests_idx_mapping[request_id] @@ -421,13 +418,19 @@ def filter( read_offset=len(all_input_ids) - stopping_criteria.current_tokens, skip_special_tokens=True, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - generate_pb2.FINISH_REASON_TERMINATED, - seed if do_sample else None, + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed if do_sample else None, + ), + ) ) - terminated_generations.append(generated_text) + if not kept_requests: + return None, terminated_generations device = self.input_ids.device diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index c19f804eada..0b7868fceb3 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -33,7 +33,7 @@ def filter( model, kept_requests: List[generate_pb2.KeptRequest], terminated_request_ids: List[int], - ) -> Tuple["Batch", List[generate_pb2.GeneratedText]]: + ) -> Tuple[Optional["Batch"], List[generate_pb2.TerminatedGeneration]]: raise NotImplementedError @classmethod diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 86df66e7eac..14297669e5a 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -86,10 +86,12 @@ async def FilterBatch(self, request, context): filtered_batch, terminated_generations = batch.filter( self.model, request.kept_requests, request.terminated_request_ids ) - self.cache.set(filtered_batch) + if filtered_batch is not None: + self.cache.set(filtered_batch) return generate_pb2.FilterBatchResponse( - batch=filtered_batch.to_pb(), terminated_generations=terminated_generations + batch=filtered_batch.to_pb() if filtered_batch is not None else None, + terminated_generations=terminated_generations, ) async def Warmup(self, request, context): From 37266e2dbb8b5f52cff80cf837a5e479c51172fe Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 11 Jun 2024 17:11:16 +0200 Subject: [PATCH 10/18] fix rust and python unit-tests --- .github/workflows/trufflehog.yml | 1 - router/src/infer/v3/block_allocator.rs | 52 +++++++++++++++---- router/src/infer/v3/queue.rs | 10 ++-- router/src/infer/v3/scheduler.rs | 35 ++++++------- server/tests/models/test_bloom.py | 26 +++++----- server/tests/models/test_causal_lm.py | 33 ++++++------ 
server/tests/models/test_seq2seq_lm.py | 26 +++++----- .../models/causal_lm.py | 48 ++++++++++++++--- .../models/idefics_causal_lm.py | 52 ++++++++++++++++--- server/text_generation_server/models/mamba.py | 48 ++++++++++++++--- .../models/seq2seq_lm.py | 49 ++++++++++++++--- .../models/vlm_causal_lm.py | 20 ++++--- 12 files changed, 288 insertions(+), 112 deletions(-) diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index 8bc60eff6ff..b406d43b8f0 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -16,4 +16,3 @@ jobs: fetch-depth: 0 - name: Secret Scanning uses: trufflesecurity/trufflehog@main - diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index a084a5056e9..563f173fe3a 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,7 +1,8 @@ -use std::sync::{Arc, Mutex}; +use std::fmt::Formatter; +use std::sync::{Arc, Mutex, TryLockError}; use thiserror::Error; -#[derive(Debug, Clone)] +#[derive(Clone)] pub(crate) struct BlockAllocation { allocated_blocks: Vec, allocated_slots: Vec, @@ -53,7 +54,19 @@ impl Drop for BlockAllocation { } } -#[derive(Debug, Clone)] +impl std::fmt::Debug for BlockAllocation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BlockAllocation") + .field("allocated_blocks", &self.allocated_blocks.len()) + .field("allocated_slots", &self.allocated_slots.len()) + .field("required_blocks", &self.required_blocks) + .field("required_slots", &self.required_slots) + .field("block_allocator", &self.block_allocator) + .finish() + } +} + +#[derive(Clone)] pub(crate) struct BlockAllocator { free_blocks: Arc>>, block_size: u32, @@ -129,8 +142,7 @@ impl BlockAllocator { Err(AllocationError::NotEnoughPages) } else { let n_free_blocks = free_blocks.len(); - let allocated_blocks = - free_blocks.split_off(n_free_blocks - clipped_required_blocks); + let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks); let allocated_blocks = if repeats != 1 { let mut allocated_blocks = allocated_blocks.repeat(repeats); @@ -140,9 +152,8 @@ impl BlockAllocator { allocated_blocks }; - let mut allocated_slots = Vec::with_capacity( - allocated_blocks.len() * self.block_size as usize * repeats, - ); + let mut allocated_slots = + Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats); let required_slots = (prompt_tokens + decode_tokens) as usize; @@ -166,7 +177,30 @@ impl BlockAllocator { } pub(crate) fn free(&self, blocks: Vec) { - self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks) + self.free_blocks + .lock() + .expect("Lock could not be acquired. 
This is a bug.") + .extend(blocks) + } +} + +impl std::fmt::Debug for BlockAllocator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("BlockAllocator"); + d.field("block_size", &self.block_size) + .field("window_size", &self.window_size); + match self.free_blocks.try_lock() { + Ok(guard) => { + d.field("free_blocks", &(*guard).len()); + } + Err(TryLockError::Poisoned(err)) => { + d.field("free_blocks", &(**err.get_ref()).len()); + } + Err(TryLockError::WouldBlock) => { + d.field("free_blocks", &format_args!("")); + } + }; + d.finish() } } diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index db09f9b4596..d8085800173 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -275,7 +275,9 @@ impl State { if prefill_tokens > prefill_token_budget { // Entry is over budget // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + tracing::debug!( + "Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}" + ); self.entries.push_front((id, entry)); break; } @@ -456,7 +458,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], - input_length: 0, + input_length: 1, truncate: 0, decoder_input_details: false, parameters: ValidParameters { @@ -567,7 +569,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -689,7 +691,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(true, 1, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ee93c20ab35..c03328b2e70 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -256,11 +256,7 @@ async fn prefill( .expect("ID not found in entries. 
This is a bug."); // Send intermediate responses - if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }) { + if send_stream_responses(stream_responses, entry).is_err() { // Sending failed, remove entry entries .remove(&id) @@ -405,7 +401,7 @@ async fn filter_batch( .filter_batch( id, updated_requests, - terminated_entries.keys().map(|v| *v).collect(), + terminated_entries.keys().copied().collect(), ) .await .unwrap() @@ -460,11 +456,14 @@ fn send_terminated_generations( }; // Send responses - if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| { + let send_result = entry.response_tx.send(Ok(response)).map_err(|err| { tracing::error!("Entry response channel error."); metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); err - }) { + }); + + if send_result.is_err() { + // The channel is dropped, skip the rest of the messages continue 'terminated_generations; } } @@ -504,11 +503,7 @@ fn filter_send_ended_generations( // If the generation has ended for this request, we send the responses to the channel and // remove the entry to drop it and free its blocks if finished { - let _ = send_stream_responses(stream_responses, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }); + let _ = send_stream_responses(stream_responses, entry); // Remove from entries and filter entries.remove(&id).expect("ID not found in entries. This is a bug."); return None; @@ -525,7 +520,11 @@ fn send_stream_responses( entry: &Entry, ) -> Result<(), Box>>> { for response in stream_responses { - entry.response_tx.send(Ok(response))?; + entry.response_tx.send(Ok(response)).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + })?; } Ok(()) } @@ -541,7 +540,7 @@ fn filter_send_update_allocations( ) -> (bool, IntMap) { let mut updated = false; - let ids: Vec = entries.keys().map(|v| *v).collect(); + let ids: Vec = entries.keys().copied().collect(); let mut terminated_entries = IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default()); @@ -581,11 +580,7 @@ fn filter_send_update_allocations( .expect("ID not found in stream_responses. 
This is a bug."); // Send intermediate responses - if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }) { + if send_stream_responses(stream_response, entry).is_err() { // Sending failed, remove entry entries .remove(id) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 0daa5f418eb..78bde639faf 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi( # Copy stopping_criterias before filtering stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy() - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_bloom, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [], ) for _ in range( @@ -307,15 +309,13 @@ def test_batch_concatenate( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_bloom, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), - generate_pb2.UpdatedRequest( - id=next_batch.requests[1].id, blocks=[], slots=[] - ), - ] + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + ], + [], ) for _ in range( @@ -339,8 +339,10 @@ def test_batch_concatenate( == default_bloom_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_bloom, + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], + [], ) for _ in range( diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 547da81f76c..6716606cb97 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi( default_multi_requests_causal_lm_batch.stopping_criterias.copy() ) - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_causal_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [], ) for _ in range( @@ -307,15 +309,13 @@ def test_batch_concatenate( == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_causal_lm, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), - generate_pb2.UpdatedRequest( - id=next_batch.requests[1].id, blocks=[], slots=[] - ), - ] + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + ], + [], ) for _ in range( @@ -337,15 +337,12 @@ def test_batch_concatenate( == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_causal_lm, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), 
- generate_pb2.UpdatedRequest( - id=next_batch.requests[1].id, blocks=[], slots=[] - ), - ] + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + ], + [], ) for _ in range( diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 17b5fa50c41..f1d2bb75df3 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi( ) assert generations[1].generated_text.generated_tokens == 5 - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_seq2seq_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [], ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) @@ -341,15 +343,13 @@ def test_batch_concatenate( ) assert generations[2].generated_text.generated_tokens == 5 - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_seq2seq_lm, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), - generate_pb2.UpdatedRequest( - id=next_batch.requests[1].id, blocks=[], slots=[] - ), - ] + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + ], + [], ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) @@ -360,8 +360,10 @@ def test_batch_concatenate( assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id assert generations[0].generated_text.generated_tokens == 7 - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_seq2seq_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], + [], ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 40a4f100221..f3b94e8cbbb 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -159,14 +159,48 @@ def from_pb( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.KeptRequest] - ) -> Optional["CausalLMBatch"]: - request_ids = [r.id for r in updated_requests] + self, + model: "CausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + 
text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + if not kept_requests: + return None, terminated_generations - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -262,7 +296,7 @@ def filter( self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens - return self + return self, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 495a47e5e22..f92378cbcac 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -215,15 +215,51 @@ def from_pb_processor( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.KeptRequest] - ) -> Optional["IdeficsCausalLMBatch"]: - request_ids = [r.id for r in updated_requests] + self, + model: "IdeficsCausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[ + Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration] + ]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + + if not kept_requests: + return None, terminated_generations - # It deletes requests from the batch. 
For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -330,7 +366,7 @@ def filter( self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens - return self + return self, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 0340ca5519e..64cb739e1dd 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -196,14 +196,48 @@ def from_pb( ) def filter( - self, updated_requests: List[generate_pb2.KeptRequest] - ) -> Optional["MambaBatch"]: - request_ids = [r.id for r in updated_requests] + self, + model: "Mamba", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + if not kept_requests: + return None, terminated_generations - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -278,7 +312,7 @@ def filter( :, indices ] self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] - return self + return self, terminated_generations @classmethod def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 77407118d44..3cf874fac26 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -167,14 +167,49 @@ def from_pb( @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.KeptRequest] - ) -> Optional["Seq2SeqLMBatch"]: - request_ids = [r.id for r in updated_requests] + self, + model: "Seq2SeqLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_decoder_input_ids = self.all_decoder_input_ids[idx] + decoder_input_length = 
self.decoder_input_lengths[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_decoder_input_ids, + prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1, + read_offset=len(all_decoder_input_ids) - decoder_input_length, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + if not kept_requests: + return None, terminated_generations - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -281,7 +316,7 @@ def filter( self.padding_right_offset = padding_right_offset self.max_tokens = max_tokens - return self + return self, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index e3d0bee8690..cee8b45fd9b 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -123,13 +123,19 @@ def concatenate(cls, batches): @tracer.start_as_current_span("filter") def filter( - self, updated_requests: List[generate_pb2.KeptRequest] - ) -> Optional["VlmCausalLMBatch"]: - batch = super().filter(updated_requests) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch + self, + model: "VlmCausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: + batch, terminated_generations = super().filter( + model, kept_requests, terminated_request_ids + ) + if batch is not None: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch, terminated_generations @classmethod def batch_tokenized_inputs( From c2fb459bc1a3a207308243f5fcc32bf6781618d0 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 11 Jun 2024 18:40:38 +0200 Subject: [PATCH 11/18] fix windowing --- router/src/infer/v3/block_allocator.rs | 101 ++++++++++++++----------- router/src/infer/v3/scheduler.rs | 9 ++- 2 files changed, 62 insertions(+), 48 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 563f173fe3a..18480dbb2b8 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -24,6 +24,9 @@ impl BlockAllocation { &self.allocated_slots } + /// Extend an allocation by adding a new block + /// If the allocation length > window size, repeats blocks and slots to cover the + /// whole `required_blocks` and `required_slots` pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { let (block, slots) = self.block_allocator.allocate_block()?; // Add block and slots to current allocation @@ -48,6 +51,7 @@ impl BlockAllocation { } impl Drop for 
BlockAllocation { + /// Free the blocks fn drop(&mut self) { let allocated_blocks = std::mem::take(&mut self.allocated_blocks); self.block_allocator.free(allocated_blocks) @@ -114,66 +118,71 @@ impl BlockAllocator { let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; // prompt blocks + a single block for decode let required_blocks = required_prompt_blocks + 1; + let required_slots = required_blocks * self.block_size; + + // Slots and blocks required for the whole request + let total_slots = prompt_tokens + decode_tokens; + let total_required_blocks = (total_slots + self.block_size - 1) / self.block_size; let (clipped_required_blocks, repeats) = match self.window_size { - // Nothing to do - None => (required_blocks, 1), - Some(window_size) => { + Some(window_size) if required_slots >= window_size => { // Number of blocks for this window size let window_size_blocks = (window_size + self.block_size - 1) / self.block_size; + // Number of times we will need to repeat blocks to cover the total allocation + let repeats = (total_slots + window_size - 1) / window_size; + (window_size_blocks, repeats) + } + // Nothing to do + _ => (required_blocks, 1), + }; - if required_blocks > window_size_blocks { - // Number of times we will need to repeat blocks to cover the required allocation - let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks; - (window_size_blocks, repeats) - } else { - (required_blocks, 1) - } + // Scoped to drop the lock early + let allocated_blocks = { + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + let clipped_required_blocks = clipped_required_blocks as usize; + + if clipped_required_blocks > free_blocks.len() { + // Not enough blocks to cover this allocation + // Early return + return Err(AllocationError::NotEnoughPages); } + + // Take the blocks + let n_free_blocks = free_blocks.len(); + free_blocks.split_off(n_free_blocks - clipped_required_blocks) }; let repeats = repeats as usize; - let required_blocks = required_blocks as usize; - let clipped_required_blocks = clipped_required_blocks as usize; + let total_slots = total_slots as usize; + let total_required_blocks = total_required_blocks as usize; - let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); - - if clipped_required_blocks > free_blocks.len() { - Err(AllocationError::NotEnoughPages) + let allocated_blocks = if repeats != 1 { + let mut allocated_blocks = allocated_blocks.repeat(repeats); + allocated_blocks.truncate(total_required_blocks); + allocated_blocks } else { - let n_free_blocks = free_blocks.len(); - let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks); - - let allocated_blocks = if repeats != 1 { - let mut allocated_blocks = allocated_blocks.repeat(repeats); - allocated_blocks.truncate(required_blocks); - allocated_blocks - } else { - allocated_blocks - }; - - let mut allocated_slots = - Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats); - - let required_slots = (prompt_tokens + decode_tokens) as usize; - - 'slots: for block_id in allocated_blocks.iter() { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - allocated_slots.push(s); - if allocated_slots.len() > required_slots { - break 'slots; - } + allocated_blocks + }; + + let mut allocated_slots = + Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats); + + 'slots: for block_id in allocated_blocks.iter() { + for s in 
(block_id * self.block_size)..((block_id + 1) * self.block_size) { + allocated_slots.push(s); + if allocated_slots.len() > total_slots { + break 'slots; } } - - Ok(BlockAllocation { - allocated_blocks, - allocated_slots, - required_blocks, - required_slots, - block_allocator: self.clone(), - }) } + + Ok(BlockAllocation { + allocated_blocks, + allocated_slots, + required_blocks: total_required_blocks, + required_slots: total_slots, + block_allocator: self.clone(), + }) } pub(crate) fn free(&self, blocks: Vec) { diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index c03328b2e70..6e5ffa7efb0 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -361,6 +361,7 @@ async fn decode( } /// Filter a `batch` and remove all requests not present in `entries` +/// Ask the server to generate the full texts for entries in `terminated_entries` #[instrument(skip_all)] async fn filter_batch( client: &mut ShardedClient, @@ -408,7 +409,10 @@ async fn filter_batch( } } -/// +/// Send `InferStreamResponse::Intermediate` and the final `InferStreamResponse::End` messages +/// to terminated requests +/// It modifies the last `InferStreamResponse::Intermediate` to add the final full text in +/// `terminated_generations` #[instrument(skip_all)] fn send_terminated_generations( terminated_generations: Vec, @@ -530,7 +534,7 @@ fn send_stream_responses( } /// Check if block allocations need to be extended -/// If we don't have enough blocks, request will be filtered with be added to an IntMap of +/// If we don't have enough blocks, request will be filtered and added to an IntMap of /// terminated entries. /// If at least one entry allocation was extended, we return true to force an update #[instrument(skip_all)] @@ -592,6 +596,7 @@ fn filter_send_update_allocations( } /// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>` +/// `bool` is `true` if the generation is finished fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec) { let mut finished = false; let mut stream_responses = Vec::with_capacity(16); From 9ac7b7bc521495b3b6335240d9b3311c79a47c7f Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 12 Jun 2024 11:50:31 +0200 Subject: [PATCH 12/18] remove slots from grpc --- benchmark/src/generation.rs | 1 - proto/v3/generate.proto | 4 -- router/client/src/v3/client.rs | 1 - router/client/src/v3/sharded_client.rs | 1 - router/src/infer/v3/block_allocator.rs | 44 +++++-------------- router/src/infer/v3/queue.rs | 18 ++++---- router/src/infer/v3/scheduler.rs | 15 +++---- .../models/flash_causal_lm.py | 43 ++++++++++-------- 8 files changed, 52 insertions(+), 75 deletions(-) diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b82d23ba41b..e5fbdca4a2a 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -156,7 +156,6 @@ async fn prefill( }), top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], - slots: vec![], }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 8138e4fb4d5..c6e02034b27 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -132,8 +132,6 @@ message Request { uint32 top_n_tokens = 7; /// Paged attention blocks repeated uint32 blocks = 9; - /// Paged attention slots - repeated uint32 slots = 10; } message Batch { @@ -208,8 +206,6 @@ message KeptRequest { uint64 id = 1; /// Paged attention blocks repeated uint32 blocks = 2; - /// Paged attention slots - repeated 
uint32 slots = 3; } /// kept_requests + terminated_request_ids might not cover all requests from the diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 1f8070cade6..03efd4f5d47 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -157,7 +157,6 @@ impl Client { truncate, // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], - slots: vec![], // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 3f11e101bf8..1f9ec3ad1ed 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -250,7 +250,6 @@ impl Health for ShardedClient { top_n_tokens: 0, // Block 0 is reserved for health checks blocks: vec![0], - slots: (0..16).collect(), }; let batch = Batch { id: u64::MAX, diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 18480dbb2b8..e19450e3820 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,11 +1,12 @@ +use std::cmp::min; use std::fmt::Formatter; use std::sync::{Arc, Mutex, TryLockError}; use thiserror::Error; #[derive(Clone)] pub(crate) struct BlockAllocation { + block_size: usize, allocated_blocks: Vec, - allocated_slots: Vec, required_blocks: usize, required_slots: usize, block_allocator: BlockAllocator, @@ -13,25 +14,20 @@ pub(crate) struct BlockAllocation { impl BlockAllocation { pub(crate) fn len(&self) -> usize { - self.allocated_slots.len() + self.allocated_blocks.len() * self.block_size } pub(crate) fn blocks(&self) -> &[u32] { &self.allocated_blocks } - pub(crate) fn slots(&self) -> &[u32] { - &self.allocated_slots - } - /// Extend an allocation by adding a new block /// If the allocation length > window size, repeats blocks and slots to cover the /// whole `required_blocks` and `required_slots` pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { - let (block, slots) = self.block_allocator.allocate_block()?; + let block = self.block_allocator.allocate_block()?; // Add block and slots to current allocation self.allocated_blocks.push(block); - self.allocated_slots.extend(slots); if let Some(window_size) = self.block_allocator.window_size { // if we have more slots than the window size, @@ -41,8 +37,6 @@ impl BlockAllocation { let repeats = (self.required_slots + window_size - 1) / window_size; self.allocated_blocks = self.allocated_blocks.repeat(repeats); self.allocated_blocks.truncate(self.required_blocks); - self.allocated_slots = self.allocated_slots.repeat(repeats); - self.allocated_slots.truncate(self.required_slots); } } @@ -62,7 +56,6 @@ impl std::fmt::Debug for BlockAllocation { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("BlockAllocation") .field("allocated_blocks", &self.allocated_blocks.len()) - .field("allocated_slots", &self.allocated_slots.len()) .field("required_blocks", &self.required_blocks) .field("required_slots", &self.required_slots) .field("block_allocator", &self.block_allocator) @@ -94,30 +87,29 @@ impl BlockAllocator { } } - fn allocate_block(&self) -> Result<(u32, Vec), AllocationError> { + fn allocate_block(&self) -> Result { let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); if free_blocks.is_empty() { return Err(AllocationError::NotEnoughPages); } - let 
block_id = free_blocks.pop().unwrap(); - let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect(); - Ok((block_id, slots)) + Ok(free_blocks.pop().unwrap()) } /// For prompt tokens, we allocate enough blocks to cover all tokens - /// For decode tokens, we allocate block by block + /// For decode tokens, we allocate min(decode_blocks, 16) blocks /// - /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots + /// If allocation > window size, we repeat blocks and slots pub(crate) fn block_allocation( &self, prompt_tokens: u32, decode_tokens: u32, ) -> Result { let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; - // prompt blocks + a single block for decode - let required_blocks = required_prompt_blocks + 1; + // prompt blocks + 16 blocks for decode + let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size; + let required_blocks = required_prompt_blocks + min(decode_blocks, 16); let required_slots = required_blocks * self.block_size; // Slots and blocks required for the whole request @@ -164,21 +156,9 @@ impl BlockAllocator { allocated_blocks }; - let mut allocated_slots = - Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats); - - 'slots: for block_id in allocated_blocks.iter() { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - allocated_slots.push(s); - if allocated_slots.len() > total_slots { - break 'slots; - } - } - } - Ok(BlockAllocation { + block_size: self.block_size as usize, allocated_blocks, - allocated_slots, required_blocks: total_required_blocks, required_slots: total_slots, block_allocator: self.clone(), diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index d8085800173..43d2bdd84f3 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -224,6 +224,11 @@ impl State { } } + // Check if max_size == 0 + if max_size == Some(0) { + return None; + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; @@ -312,14 +317,10 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), - Some(block_allocation) => ( - block_allocation.blocks().to_vec(), - block_allocation.slots().to_vec(), - ), - }; - + let blocks = block_allocation + .as_ref() + .map(|block_allocation| block_allocation.blocks().to_vec()) + .unwrap_or_default(); entry.block_allocation = block_allocation; batch_requests.push(Request { @@ -338,7 +339,6 @@ impl State { )), top_n_tokens: entry.request.top_n_tokens, blocks, - slots, }); // Set batch_time entry.batch_time = Some(Instant::now()); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 6e5ffa7efb0..50b33951261 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -164,7 +164,7 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue @@ -382,16 +382,15 @@ async fn filter_batch( let updated_requests = entries .iter() .map(|(request_id, entry)| { - let 
(blocks, slots) = entry + let blocks = entry .block_allocation .as_ref() - .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec())) + .map(|alloc| alloc.blocks().to_vec()) .unwrap_or_default(); KeptRequest { id: *request_id, blocks, - slots, } }) .collect(); @@ -991,10 +990,10 @@ mod tests { content: "You are a friendly chatbot who always responds in the style of a pirate" .to_string(), }] - .iter() - .chain(&example_chat) - .cloned() - .collect::>(); + .iter() + .chain(&example_chat) + .cloned() + .collect::>(); let test_default_templates = vec![ ChatTemplateTestItem { diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e8fd8b1667c..1bf9b7a5356 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -133,9 +133,12 @@ def batch_tokenized_inputs( batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) max_truncation = max(max_truncation, r.truncate) + logger.error(batch_inputs) + batch_tokenized_inputs = tokenizer( batch_inputs, truncation=True, max_length=max_truncation )["input_ids"] + logger.error(batch_tokenized_inputs) return batch_tokenized_inputs @classmethod @@ -179,7 +182,7 @@ def from_tokenized( max_blocks = 0 block_tables = [] - flat_slots = [] + flat_blocks = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -231,24 +234,18 @@ def from_tokenized( request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] - request_slots = [ - s - for b in request_blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] else: request_blocks = r.blocks - request_slots = r.slots block_tables.append(request_blocks) num_blocks += len(request_blocks) request_slot_indices = torch.arange( - len(flat_slots), - len(flat_slots) + input_length, + len(flat_blocks) * BLOCK_SIZE, + (len(flat_blocks) * BLOCK_SIZE) + input_length, dtype=torch.int64, ) - flat_slots.extend(request_slots) + flat_blocks.extend(request_blocks) slot_indices.append(request_slot_indices) # Create tensor to slice into the kv tensor in prefill @@ -347,7 +344,13 @@ def from_tokenized( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(flat_slots, dtype=torch.int64, device=device) + flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device) + + slots = ( + (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T + + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) + ).flatten() + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) @@ -444,8 +447,8 @@ def filter( max_seqlen = 0 requests = [] + flat_blocks = [] block_tables = [] - flat_slots = [] all_input_ids = [] input_lengths = [] @@ -483,16 +486,13 @@ def filter( top_n_tokens.append(self.top_n_tokens[idx]) request_block_table = request.blocks - num_blocks += len(request_block_table) block_tables.append(request_block_table) - - # List of slots allocated for this request - request_slots = request.slots + flat_blocks.extend(request_block_table) # Index - slot_indices.append(len(flat_slots) + request_input_length - 1) - flat_slots.extend(request_slots) + slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1) + num_blocks += len(request_block_table) max_blocks = max(max_blocks, len(request_block_table)) # Index into tensors @@ -514,11 +514,16 @@ def filter( block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) # Allocate on GPU - slots = 
torch.tensor(flat_slots, dtype=torch.int64, device=device) slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) # Move to GPU block_tables_tensor = block_tables_tensor.to(device) + flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device) + + slots = ( + (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T + + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) + ).flatten() filtered_batch = type(self)( batch_id=self.batch_id, From 05eb4dcb1739868b63ed540fea9129b52a2242b5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:53:14 +0200 Subject: [PATCH 13/18] allocate 16 by 16 --- router/src/infer/v3/block_allocator.rs | 31 +++++++++++++++---- router/src/infer/v3/scheduler.rs | 11 ++++--- .../models/flash_causal_lm.py | 3 -- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index e19450e3820..4a60dae670f 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -21,13 +21,28 @@ impl BlockAllocation { &self.allocated_blocks } - /// Extend an allocation by adding a new block + /// Extend an allocation by adding new blocks /// If the allocation length > window size, repeats blocks and slots to cover the /// whole `required_blocks` and `required_slots` pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { - let block = self.block_allocator.allocate_block()?; + let required_blocks = match self.block_allocator.window_size { + None => self.required_blocks, + Some(window_size) => min( + (window_size as usize + self.block_size - 1) / self.block_size, + self.required_blocks, + ), + }; + let remaining_blocks = required_blocks - self.allocated_blocks.len(); + let new_blocks = min(remaining_blocks, 16); + + // Try to allocate all remaining blocks + let blocks = match self.block_allocator.allocate_blocks(new_blocks) { + Ok(blocks) => blocks, + // Failed, try to allocate one block + Err(_) => self.block_allocator.allocate_blocks(1)?, + }; // Add block and slots to current allocation - self.allocated_blocks.push(block); + self.allocated_blocks.extend(blocks); if let Some(window_size) = self.block_allocator.window_size { // if we have more slots than the window size, @@ -87,14 +102,18 @@ impl BlockAllocator { } } - fn allocate_block(&self) -> Result { + fn allocate_blocks(&self, blocks: usize) -> Result, AllocationError> { let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); - if free_blocks.is_empty() { + if blocks > free_blocks.len() { + // Not enough blocks to cover this allocation + // Early return return Err(AllocationError::NotEnoughPages); } - Ok(free_blocks.pop().unwrap()) + // Take the blocks + let n_free_blocks = free_blocks.len(); + Ok(free_blocks.split_off(n_free_blocks - blocks)) } /// For prompt tokens, we allocate enough blocks to cover all tokens diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 50b33951261..a901ba69a6b 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -164,7 +164,8 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if 
let Some((mut new_entries, new_batch, span)) = queue @@ -990,10 +991,10 @@ mod tests { content: "You are a friendly chatbot who always responds in the style of a pirate" .to_string(), }] - .iter() - .chain(&example_chat) - .cloned() - .collect::>(); + .iter() + .chain(&example_chat) + .cloned() + .collect::>(); let test_default_templates = vec![ ChatTemplateTestItem { diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1bf9b7a5356..47963aba4db 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -133,12 +133,9 @@ def batch_tokenized_inputs( batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) max_truncation = max(max_truncation, r.truncate) - logger.error(batch_inputs) - batch_tokenized_inputs = tokenizer( batch_inputs, truncation=True, max_length=max_truncation )["input_ids"] - logger.error(batch_tokenized_inputs) return batch_tokenized_inputs @classmethod From abe521204ec97030f43e55409b960c404a8c3fa2 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:54:25 +0200 Subject: [PATCH 14/18] fix tests --- server/tests/models/test_bloom.py | 8 ++++---- server/tests/models/test_causal_lm.py | 8 ++++---- server/tests/models/test_seq2seq_lm.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 78bde639faf..79cd00ccfde 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -199,7 +199,7 @@ def test_causal_lm_generate_token_completion_multi( next_batch, _ = next_batch.filter( default_bloom, - [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])], [], ) @@ -312,8 +312,8 @@ def test_batch_concatenate( next_batch, _ = next_batch.filter( default_bloom, [ - generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), - generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), ], [], ) @@ -341,7 +341,7 @@ def test_batch_concatenate( next_batch, _ = next_batch.filter( default_bloom, - [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])], [], ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6716606cb97..c807a15e048 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -200,7 +200,7 @@ def test_causal_lm_generate_token_completion_multi( next_batch, _ = next_batch.filter( default_causal_lm, - [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])], [], ) @@ -312,8 +312,8 @@ def test_batch_concatenate( next_batch, _ = next_batch.filter( default_causal_lm, [ - generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), - generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), ], [], ) @@ -340,7 +340,7 @@ def 
test_batch_concatenate( next_batch, _ = next_batch.filter( default_causal_lm, [ - generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), ], [], ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index f1d2bb75df3..2eea96275e7 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -208,7 +208,7 @@ def test_seq2seq_lm_generate_token_completion_multi( next_batch, _ = next_batch.filter( default_seq2seq_lm, - [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])], [], ) @@ -346,8 +346,8 @@ def test_batch_concatenate( next_batch, _ = next_batch.filter( default_seq2seq_lm, [ - generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), - generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), ], [], ) @@ -362,7 +362,7 @@ def test_batch_concatenate( next_batch, _ = next_batch.filter( default_seq2seq_lm, - [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])], [], ) From 7ed1044585b8fce59facf5404b02919d29e935de Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 18 Jun 2024 12:18:05 +0200 Subject: [PATCH 15/18] added padded blocks and logs everywhere --- benchmark/src/app.rs | 12 +++--- benchmark/src/table.rs | 6 +-- benchmark/src/utils.rs | 2 +- proto/v3/generate.proto | 2 + router/src/infer/v3/block_allocator.rs | 2 +- router/src/infer/v3/scheduler.rs | 39 +++++++++++++++++-- rust-toolchain.toml | 6 +-- .../models/flash_causal_lm.py | 29 ++++++++++---- 8 files changed, 73 insertions(+), 25 deletions(-) diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs index 48ac976a0c2..a0a9313a198 100644 --- a/benchmark/src/app.rs +++ b/benchmark/src/app.rs @@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Lowest: {:.2} {unit}", data.iter() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec> { "Highest: {:.2} {unit}", data.iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN) + .unwrap_or(&f64::NAN) ), Style::default().fg(Color::Reset), )]), @@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>( let min_latency: f64 = *latency_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_latency: f64 = *latency_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let min_throughput: f64 = *throughput_iter .clone() .min_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max_throughput: f64 = *throughput_iter .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); // Char min max values let min_x = if zoom { diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index e18d7310a35..1585a25f4fc 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) { let min = data .iter() .min_by(|a, b| 
a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); let max = data .iter() .max_by(|a, b| a.total_cmp(b)) - .unwrap_or(&std::f64::NAN); + .unwrap_or(&f64::NAN); (average, *min, *max) } fn px(data: &[f64], p: u32) -> f64 { let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; - *data.get(i).unwrap_or(&std::f64::NAN) + *data.get(i).unwrap_or(&f64::NAN) } fn format_value(value: f64, unit: &'static str) -> String { diff --git a/benchmark/src/utils.rs b/benchmark/src/utils.rs index d096d65510f..20469991c39 100644 --- a/benchmark/src/utils.rs +++ b/benchmark/src/utils.rs @@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap { // Filter next batch: remove requests that were stopped and update blocks/slots let (filtered_batch, terminated_generations) = filter_batch(client, batch, entries, &terminated_entries).await; + tracing::info!("filter_batch: {:?}", start_filtering_time.elapsed()); send_terminated_generations( terminated_generations, terminated_entries, filtered_stream_responses, ); + tracing::info!("send_terminated: {:?}", start_filtering_time.elapsed()); filtered_batch } @@ -379,23 +386,49 @@ async fn filter_batch( client.clear_cache(Some(id)).await.unwrap(); Default::default() } else { - // Collect new blocks/slots + let max_blocks = entries + .iter() + .map(|(_, entry)| { + entry + .block_allocation + .as_ref() + .map(|alloc| alloc.blocks().len()) + }) + .max() + .flatten(); + + let start_time = Instant::now(); + + // Collect new blocks let updated_requests = entries .iter() .map(|(request_id, entry)| { - let blocks = entry + let (blocks, padded_blocks) = entry .block_allocation .as_ref() - .map(|alloc| alloc.blocks().to_vec()) + .map(|alloc| { + let max_blocks = match max_blocks { + Some(max_blocks) => max_blocks, + _ => unreachable!(), + }; + + let blocks = alloc.blocks().to_vec(); + let mut padded_blocks = blocks.clone(); + padded_blocks.resize(max_blocks - padded_blocks.len(), 0); + (blocks, padded_blocks) + }) .unwrap_or_default(); KeptRequest { id: *request_id, blocks, + padded_blocks, } }) .collect(); + tracing::info!("Collect blocks: {:?}", start_time.elapsed()); + // Filter Python shard cache // We unwrap here as we need to panic since we cannot recover if this method fails client diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 507ee859411..83f9a5b0eaa 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -# Released on: 02 May, 2024 -# https://releases.rs/docs/1.78.0/ -channel = "1.78.0" +# Released on: 13 June, 2024 +# https://releases.rs/docs/1.79.0/ +channel = "1.79.0" components = ["rustfmt", "clippy"] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 47963aba4db..1182f3d4b29 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -403,6 +403,8 @@ def filter( kept_requests: List[generate_pb2.KeptRequest], terminated_request_ids: List[int], ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: + start = time.time_ns() + terminated_generations = [] for request_id in terminated_request_ids: idx = self.requests_idx_mapping[request_id] @@ -429,6 +431,11 @@ def filter( ), ) ) + + from loguru import logger + + logger.info(f"terminated generations {(time.time_ns() - start)/1e6}") + if not kept_requests: return None, terminated_generations @@ -445,7 +452,7 @@ def filter( requests = [] 
flat_blocks = [] - block_tables = [] + padded_blocks = [] all_input_ids = [] input_lengths = [] @@ -483,8 +490,8 @@ def filter( top_n_tokens.append(self.top_n_tokens[idx]) request_block_table = request.blocks - block_tables.append(request_block_table) flat_blocks.extend(request_block_table) + padded_blocks.extend(request.padded_blocks) # Index slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1) @@ -492,6 +499,8 @@ def filter( num_blocks += len(request_block_table) max_blocks = max(max_blocks, len(request_block_table)) + logger.info(f"for loop requests: {(time.time_ns() - start)/1e6}") + # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] @@ -503,12 +512,14 @@ def filter( self.speculative_ids[indices] if self.speculative_ids is not None else None ) - # Create block_tables_tensor on CPU - block_tables_tensor = torch.zeros( - (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" - ) - for i, request_blocks in enumerate(block_tables): - block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) + logger.info(f"slice objects: {(time.time_ns() - start)/1e6}") + + # Create block_tables_tensor on GPU + block_tables_tensor = torch.tensor( + padded_blocks, dtype=torch.int32, device=device + ).view(len(requests), -1) + + logger.info(f"allocate block table: {(time.time_ns() - start)/1e6}") # Allocate on GPU slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) @@ -522,6 +533,8 @@ def filter( + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) ).flatten() + logger.info(f"done allocation: {(time.time_ns() - start)/1e6}") + filtered_batch = type(self)( batch_id=self.batch_id, requests=requests, From b21ed583acca55acc6c41a737b7794009a09545f Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:56:16 +0200 Subject: [PATCH 16/18] fix logic --- Cargo.toml | 4 ---- router/src/infer/v3/scheduler.rs | 11 +++++------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bc2da5a1124..552c0bffb30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,10 +20,6 @@ tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } [profile.release] -incremental = true - -[profile.release-binary] -inherits = "release" debug = 1 incremental = true panic = "abort" diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 23a80764b16..3c7c59f59af 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -407,14 +407,13 @@ async fn filter_batch( .block_allocation .as_ref() .map(|alloc| { - let max_blocks = match max_blocks { - Some(max_blocks) => max_blocks, - _ => unreachable!(), - }; - let blocks = alloc.blocks().to_vec(); let mut padded_blocks = blocks.clone(); - padded_blocks.resize(max_blocks - padded_blocks.len(), 0); + + if let Some(max_blocks) = max_blocks { + padded_blocks.resize(max_blocks, 0); + } + (blocks, padded_blocks) }) .unwrap_or_default(); From e5c27364be89d0e24770a6ba9df00067941ba9cb Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:44:28 +0200 Subject: [PATCH 17/18] avoid join_all --- router/client/src/v3/sharded_client.rs | 59 +++++++++++++++++--------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 
1f9ec3ad1ed..f89bf75defd 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -4,7 +4,8 @@ use crate::{ClientError, Result}; use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration}; use async_trait::async_trait; -use futures::future::join_all; +use futures::stream::FuturesUnordered; +use futures::stream::StreamExt; use tonic::transport::Uri; use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; @@ -29,8 +30,12 @@ impl ShardedClient { async fn from_master_client(mut master_client: Client) -> Result { // Get all uris/unix sockets from the master client let uris = master_client.service_discovery().await?; - let futures = uris.into_iter().map(Client::connect_uds); - let clients: Result> = join_all(futures).await.into_iter().collect(); + let futures: FuturesUnordered<_> = uris.into_iter().map(Client::connect_uds).collect(); + let clients: Result> = futures + .collect::>>() + .await + .into_iter() + .collect(); Ok(Self::new(clients?)) } @@ -49,34 +54,43 @@ impl ShardedClient { /// Get the model info #[instrument(skip(self))] pub async fn info(&mut self) -> Result { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap().map(ShardInfo::from) + futures + .collect::>>() + .await + .pop() + .unwrap() + .map(ShardInfo::from) } /// GRPC health check #[instrument(skip(self))] pub async fn health(&mut self) -> Result { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.health()) .collect(); - join_all(futures).await.pop().unwrap() + futures.collect::>>().await.pop().unwrap() } /// Clear the past generations cache #[instrument(skip(self))] pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.clear_cache(batch_id)) .collect(); - join_all(futures).await.into_iter().collect() + futures + .collect::>>() + .await + .into_iter() + .collect() } /// Filter a cached batch @@ -87,7 +101,7 @@ impl ShardedClient { kept_requests: Vec, terminated_request_ids: Vec, ) -> Result<(Option, Vec)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| { @@ -99,7 +113,7 @@ impl ShardedClient { }) .collect(); // all shards return the same message - join_all(futures).await.pop().unwrap() + futures.collect::>>().await.pop().unwrap() } /// Warmup on a max size batch @@ -113,7 +127,7 @@ impl ShardedClient { max_total_tokens: u32, max_batch_size: Option, ) -> Result> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| { @@ -126,7 +140,8 @@ impl ShardedClient { }) .collect(); // Take the minimum value - let results = join_all(futures) + let results = futures + .collect::>>() .await .into_iter() .collect::>>>()?; @@ -142,14 +157,17 @@ impl ShardedClient { &mut self, batch: Batch, ) -> Result<(Vec, Option, PrefillTimings)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| Box::pin(client.prefill(batch.clone()))) .collect(); #[allow(clippy::type_complexity)] - let results: Result, Option, PrefillTimings)>> = - join_all(futures).await.into_iter().collect(); + let results: Result, Option, PrefillTimings)>> = futures + .collect::>>() + .await + .into_iter() + .collect(); let mut results = 
results?; let (mut generations, next_batch, mut timings) = @@ -175,14 +193,17 @@ impl ShardedClient { &mut self, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| Box::pin(client.decode(batches.clone()))) .collect(); #[allow(clippy::type_complexity)] - let results: Result, Option, DecodeTimings)>> = - join_all(futures).await.into_iter().collect(); + let results: Result, Option, DecodeTimings)>> = futures + .collect::>>() + .await + .into_iter() + .collect(); let mut results = results?; let (mut generations, next_batch, mut timings) = From fe9abad1a983e5d68a3102f9a38d887ec8727d10 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:58:59 +0200 Subject: [PATCH 18/18] mirror docker --- .github/workflows/build.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8c407e81a0a..076cf8496ec 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -60,10 +60,13 @@ jobs: authkey: ${{ secrets.TAILSCALE_AUTHKEY }} slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }} slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} - - name: Initialize Docker Buildx - uses: docker/setup-buildx-action@v2.0.0 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2
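
A note on the padded-block logic that PATCH 15/18 introduces and PATCH 16/18 corrects: Vec::resize takes the desired final length, not a count of elements to append, so the original padded_blocks.resize(max_blocks - padded_blocks.len(), 0) sets each table to the wrong width (truncating it in most cases), while the fixed version passes max_blocks directly. Below is a minimal standalone sketch of the corrected padding; the pad_block_tables helper and its input values are invented for illustration and are not part of the router code.

/// Pad each request's block table with zeros so every row has the same
/// width, mirroring the corrected padded_blocks logic in scheduler.rs.
/// The tables used here are made-up values, not real paged-attention blocks.
fn pad_block_tables(tables: &[Vec<u32>]) -> Vec<Vec<u32>> {
    // Widest table across the batch (0 if the batch is empty).
    let max_blocks = tables.iter().map(|blocks| blocks.len()).max().unwrap_or(0);
    tables
        .iter()
        .map(|blocks| {
            let mut padded = blocks.clone();
            // resize takes the target length: shorter rows are zero-filled,
            // rows already at max_blocks are left untouched.
            padded.resize(max_blocks, 0);
            padded
        })
        .collect()
}

fn main() {
    let tables = vec![vec![7, 8, 9], vec![3], vec![4, 5]];
    for row in pad_block_tables(&tables) {
        println!("{row:?}"); // [7, 8, 9], [3, 0, 0], [4, 5, 0]
    }
}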
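
For the join_all to FuturesUnordered change in PATCH 17/18, here is a standalone sketch of the fan-out pattern; shard_call and broadcast are hypothetical stand-ins for the generated gRPC client methods, not functions from sharded_client.rs. One design point: a FuturesUnordered polls its futures concurrently but yields results in completion order, which is acceptable in this router because the callers either pop a single response (all shards return the same message) or fold every response into one Result.

use futures::stream::{FuturesUnordered, StreamExt};

/// Hypothetical stand-in for a per-shard gRPC call such as client.health().
async fn shard_call(shard_id: usize) -> Result<usize, String> {
    Ok(shard_id * 2)
}

/// Fan the call out to every shard and fold the responses into one Result,
/// similar to the clear_cache-style collection in the patch.
async fn broadcast(num_shards: usize) -> Result<Vec<usize>, String> {
    // Collect the per-shard futures into a FuturesUnordered so they are
    // polled concurrently and only re-polled when they are woken.
    let futures: FuturesUnordered<_> = (0..num_shards).map(shard_call).collect();
    futures
        .collect::<Vec<Result<usize, String>>>()
        .await
        .into_iter()
        .collect()
}

fn main() {
    // block_on from the futures crate's built-in executor keeps the sketch
    // free of a full async runtime dependency.
    let results = futures::executor::block_on(broadcast(4));
    println!("{results:?}"); // Ok([...]) with 0, 2, 4, 6 in completion order
}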