Skip to content

Commit

Permalink
FlashCausalLM implem
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jun 11, 2024
1 parent 9ec5d80 commit 0dbd6b3
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 129 deletions.
9 changes: 8 additions & 1 deletion proto/v3/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,18 @@ message FilterBatchRequest {
repeated uint64 terminated_request_ids = 3;
}

// Pairs a request ID with the text generated for that request.
// Used as the element type of `FilterBatchResponse.terminated_generations`.
// NOTE(review): presumably one entry is returned per ID listed in
// `FilterBatchRequest.terminated_request_ids` -- confirm against the server.
message TerminatedGeneration {
// Request ID
uint64 id = 1;
// Generated text
GeneratedText generated_text = 2;
}

message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
/// Terminated generations
repeated GeneratedText terminated_generations = 2;
repeated TerminatedGeneration terminated_generations = 2;
}


Expand Down
4 changes: 2 additions & 2 deletions router/client/src/v3/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ impl Client {
batch_id: u64,
kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
kept_requests,
terminated_request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
Ok((filtered_batch.batch, filtered_batch.terminated_generations))
}

/// Warmup on a max size batch
Expand Down
2 changes: 1 addition & 1 deletion router/client/src/v3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens,
};
pub use sharded_client::ShardedClient;
4 changes: 2 additions & 2 deletions router/client/src/v3/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};

use crate::v3::{Chunk, InfoResponse, Input};
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
Expand Down Expand Up @@ -86,7 +86,7 @@ impl ShardedClient {
batch_id: u64,
kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
Expand Down
Loading

0 comments on commit 0dbd6b3

Please sign in to comment.