feat: re-allocate pages dynamically #2024

Closed · wants to merge 19 commits
7 changes: 5 additions & 2 deletions .github/workflows/build.yaml
```diff
@@ -60,10 +60,13 @@ jobs:
           authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
           slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
           slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
-      - name: Initialize Docker Buildx
-        uses: docker/setup-buildx-action@v2.0.0
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
         with:
           install: true
+          config-inline: |
+            [registry."docker.io"]
+              mirrors = ["registry.github-runners.huggingface.tech"]
       - name: Login to GitHub Container Registry
         if: github.event_name != 'pull_request'
         uses: docker/login-action@v2
```
4 changes: 0 additions & 4 deletions Cargo.toml
```diff
@@ -20,10 +20,6 @@ tokenizers = { version = "0.19.1", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
 
 [profile.release]
-incremental = true
-
-[profile.release-binary]
-inherits = "release"
 debug = 1
 incremental = true
 panic = "abort"
```
1 change: 0 additions & 1 deletion benchmark/src/generation.rs
```diff
@@ -156,7 +156,6 @@ async fn prefill(
                 }),
                 top_n_tokens: top_n_tokens.unwrap_or(0),
                 blocks: vec![],
-                slots: vec![],
             })
             .collect();
```
30 changes: 27 additions & 3 deletions proto/v3/generate.proto
```diff
@@ -132,8 +132,6 @@ message Request {
   uint32 top_n_tokens = 7;
   /// Paged attention blocks
   repeated uint32 blocks = 9;
-  /// Paged attention slots
-  repeated uint32 slots = 10;
 }
 
 message Batch {
@@ -164,6 +162,7 @@ enum FinishReason {
   FINISH_REASON_LENGTH = 0;
   FINISH_REASON_EOS_TOKEN = 1;
   FINISH_REASON_STOP_SEQUENCE = 2;
+  FINISH_REASON_TERMINATED = 3;
 }
 
 message GeneratedText {
@@ -198,18 +197,43 @@ message Generation {
   optional GeneratedText generated_text = 4;
   /// Top tokens
   repeated Tokens top_tokens = 5;
+  /// Current length of the cache: prompt tokens + number of generated tokens until this point
+  uint32 cache_length = 6;
 }
 
+message KeptRequest {
+  /// Request ID
+  uint64 id = 1;
+  /// Paged attention blocks
+  repeated uint32 blocks = 2;
+  /// Paged attention blocks padded to max blocks for this batch
+  repeated uint32 padded_blocks = 3;
+}
+
+/// kept_requests + terminated_request_ids might not cover all requests from the
+/// cached batch as some requests can be filtered out without requiring to generate text
+/// for example if the client dropped its connection to the router
 message FilterBatchRequest {
   /// Batch ID
   uint64 batch_id = 1;
   /// Requests to keep
-  repeated uint64 request_ids = 2;
+  repeated KeptRequest kept_requests = 2;
+  /// Requests to terminate and generate text for
+  repeated uint64 terminated_request_ids = 3;
 }
 
+message TerminatedGeneration {
+  // Request ID
+  uint64 id = 1;
+  // Generated text
+  GeneratedText generated_text = 2;
+}
+
 message FilterBatchResponse {
   /// Filtered Batch (cached)
   CachedBatch batch = 1;
+  /// Terminated generations
+  repeated TerminatedGeneration terminated_generations = 2;
 }
```
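The proto changes above redefine the FilterBatch exchange: surviving requests now travel as `KeptRequest` entries carrying their re-allocated page lists, and dropped requests can be terminated server-side, coming back as `TerminatedGeneration` entries finished with `FINISH_REASON_TERMINATED`. A minimal sketch of building the new request, assuming the tonic-generated types are reachable as `pb::generate::v3` (all ids and block values below are made up):

```rust
use pb::generate::v3::{FilterBatchRequest, KeptRequest};

fn build_filter_request() -> FilterBatchRequest {
    FilterBatchRequest {
        batch_id: 7,
        // Request 3 survives: it keeps pages 0, 4 and 9, padded to the
        // batch-wide maximum number of blocks.
        kept_requests: vec![KeptRequest {
            id: 3,
            blocks: vec![0, 4, 9],
            padded_blocks: vec![0, 4, 9, 0],
        }],
        // Request 5 is cut off; its final text comes back in
        // FilterBatchResponse.terminated_generations.
        terminated_request_ids: vec![5],
    }
}
```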
11 changes: 6 additions & 5 deletions router/client/src/v3/client.rs
```diff
@@ -90,15 +90,17 @@ impl Client {
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
-    ) -> Result<Option<CachedBatch>> {
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
+    ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
         let request = tonic::Request::new(FilterBatchRequest {
             batch_id,
-            request_ids,
+            kept_requests,
+            terminated_request_ids,
         })
         .inject_context();
         let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
-        Ok(filtered_batch.batch)
+        Ok((filtered_batch.batch, filtered_batch.terminated_generations))
     }
 
     /// Warmup on a max size batch
@@ -155,7 +157,6 @@ impl Client {
             truncate,
             // Blocks and slots will be set on the server side if we use paged attention
             blocks: vec![],
-            slots: vec![],
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,
```
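A hedged call-site sketch for the updated `filter_batch` signature; the connected `Client`, the batch id, and the request ids are all assumed for illustration and are not part of this PR:

```rust
use crate::v3::{Client, KeptRequest};
use crate::Result;

async fn shrink_batch(client: &mut Client) -> Result<()> {
    // Keep request 0 on a smaller page set; terminate request 9.
    let kept = vec![KeptRequest {
        id: 0,
        blocks: vec![2, 3],
        padded_blocks: vec![2, 3, 0],
    }];
    let (batch, terminated) = client.filter_batch(1, kept, vec![9]).await?;
    for generation in terminated {
        // Each entry carries the final GeneratedText of a terminated request.
        println!("request {} terminated: {:?}", generation.id, generation.generated_text);
    }
    // `batch` is the filtered CachedBatch to keep decoding with (None if empty).
    println!("remaining cached batch: {:?}", batch);
    Ok(())
}
```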
4 changes: 2 additions & 2 deletions router/client/src/v3/mod.rs
```diff
@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
 };
 pub use sharded_client::ShardedClient;
```
77 changes: 52 additions & 25 deletions router/client/src/v3/sharded_client.rs
```diff
@@ -2,14 +2,15 @@
 use crate::{v3, Health, ShardInfo};
 use crate::{ClientError, Result};
 
-use crate::v3::{Chunk, InfoResponse, Input};
+use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
 use async_trait::async_trait;
-use futures::future::join_all;
+use futures::stream::FuturesUnordered;
+use futures::stream::StreamExt;
 use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
 use v3::{
-    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
+    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest,
     NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
@@ -29,8 +30,12 @@ impl ShardedClient {
     async fn from_master_client(mut master_client: Client) -> Result<Self> {
         // Get all uris/unix sockets from the master client
         let uris = master_client.service_discovery().await?;
-        let futures = uris.into_iter().map(Client::connect_uds);
-        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
+        let futures: FuturesUnordered<_> = uris.into_iter().map(Client::connect_uds).collect();
+        let clients: Result<Vec<Client>> = futures
+            .collect::<Vec<Result<_>>>()
+            .await
+            .into_iter()
+            .collect();
         Ok(Self::new(clients?))
     }
@@ -49,50 +54,66 @@ impl ShardedClient {
     /// Get the model info
     #[instrument(skip(self))]
     pub async fn info(&mut self) -> Result<ShardInfo> {
-        let futures: Vec<_> = self
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
             .map(|client| client.info())
             .collect();
-        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
+        futures
+            .collect::<Vec<Result<_>>>()
+            .await
+            .pop()
+            .unwrap()
+            .map(ShardInfo::from)
     }
 
     /// GRPC health check
     #[instrument(skip(self))]
     pub async fn health(&mut self) -> Result<HealthResponse> {
-        let futures: Vec<_> = self
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
             .map(|client| client.health())
             .collect();
-        join_all(futures).await.pop().unwrap()
+        futures.collect::<Vec<Result<_>>>().await.pop().unwrap()
     }
 
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
-        let futures: Vec<_> = self
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
             .map(|client| client.clear_cache(batch_id))
             .collect();
-        join_all(futures).await.into_iter().collect()
+        futures
+            .collect::<Vec<Result<_>>>()
+            .await
+            .into_iter()
+            .collect()
     }
 
     /// Filter a cached batch
     #[instrument(skip(self))]
     pub async fn filter_batch(
         &mut self,
         batch_id: u64,
-        request_ids: Vec<u64>,
-    ) -> Result<Option<CachedBatch>> {
-        let futures: Vec<_> = self
+        kept_requests: Vec<KeptRequest>,
+        terminated_request_ids: Vec<u64>,
+    ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
+            .map(|client| {
+                Box::pin(client.filter_batch(
+                    batch_id,
+                    kept_requests.clone(),
+                    terminated_request_ids.clone(),
+                ))
+            })
             .collect();
         // all shards return the same message
-        join_all(futures).await.pop().unwrap()
+        futures.collect::<Vec<Result<_>>>().await.pop().unwrap()
     }
 
     /// Warmup on a max size batch
@@ -106,7 +127,7 @@ impl ShardedClient {
         max_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
-        let futures: Vec<_> = self
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
             .map(|client| {
@@ -119,7 +140,8 @@ impl ShardedClient {
             })
             .collect();
         // Take the minimum value
-        let results = join_all(futures)
+        let results = futures
+            .collect::<Vec<Result<_>>>()
             .await
             .into_iter()
             .collect::<Result<Vec<Option<u32>>>>()?;
@@ -135,14 +157,17 @@ impl ShardedClient {
         &mut self,
         batch: Batch,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let futures: Vec<_> = self
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
-            join_all(futures).await.into_iter().collect();
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = futures
+            .collect::<Vec<Result<_>>>()
+            .await
+            .into_iter()
+            .collect();
         let mut results = results?;
 
         let (mut generations, next_batch, mut timings) =
@@ -168,14 +193,17 @@ impl ShardedClient {
         &mut self,
         batches: Vec<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
-        let futures: Vec<_> = self
+        let futures: FuturesUnordered<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
-            join_all(futures).await.into_iter().collect();
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> = futures
+            .collect::<Vec<Result<_>>>()
+            .await
+            .into_iter()
+            .collect();
         let mut results = results?;
 
         let (mut generations, next_batch, mut timings) =
@@ -243,7 +271,6 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
         };
         let batch = Batch {
             id: u64::MAX,
```
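The recurring edit in this file replaces `futures::future::join_all` with a `FuturesUnordered` stream. One behavioral note: `FuturesUnordered` yields results in completion order rather than submission order, which is harmless here because every shard returns an equivalent message and callers either collect into a single `Result` or pop an arbitrary element. A standalone sketch of the pattern (names are illustrative):

```rust
use futures::stream::{FuturesUnordered, StreamExt};

/// Await all shard calls concurrently and take any one response.
async fn any_response<F, T>(calls: Vec<F>) -> T
where
    F: std::future::Future<Output = T>,
{
    let futures: FuturesUnordered<F> = calls.into_iter().collect();
    // Results arrive in completion order, not submission order.
    let mut results = futures.collect::<Vec<T>>().await;
    results.pop().expect("at least one shard")
}
```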
3 changes: 3 additions & 0 deletions router/src/infer/mod.rs
```diff
@@ -503,6 +503,8 @@ pub enum InferError {
     TemplateError(#[from] minijinja::Error),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Request could not be re-allocated: out of pages")]
+    OutOfPages,
 }
 
 impl InferError {
@@ -514,6 +516,7 @@ impl InferError {
             InferError::IncompleteGeneration => "incomplete_generation",
             InferError::TemplateError(_) => "template_error",
             InferError::ToolError(_) => "tool_error",
+            InferError::OutOfPages => "out_of_pages",
         }
     }
 }
```
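The new `OutOfPages` variant signals that a request's pages could not be re-allocated. How the router surfaces it is outside this diff; a sketch under the assumption that an axum handler translates the error (the status-code choice is illustrative; only the `out_of_pages` string comes from the PR):

```rust
use axum::http::StatusCode;

fn status_for(err: &InferError) -> StatusCode {
    match err {
        // Page re-allocation failed; the client may retry once capacity frees up.
        InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS,
        _ => StatusCode::INTERNAL_SERVER_ERROR,
    }
}
```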